From 1f6744162f606cce895f2d9818207ddecbce5932 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 28 Feb 2026 23:49:12 +0200 Subject: [PATCH 01/75] feat: Support SCAIL WanVideo model (#12614) --- comfy/ldm/wan/model.py | 115 ++++++++++++++++++++++++++++++++++++++ comfy/model_base.py | 38 +++++++++++++ comfy/model_detection.py | 2 + comfy/supported_models.py | 12 +++- comfy_extras/nodes_wan.py | 58 +++++++++++++++++++ node_helpers.py | 31 ++++++++++ 6 files changed, 255 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index ea123acb4..b2287dba9 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + +class SCAILWanModel(WanModel): + def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs): + super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs) + + self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) + + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs): + + if reference_latent is not None: + x = torch.cat((reference_latent, x), dim=2) + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes + x = x.flatten(2).transpose(1, 2) + + scail_pose_seq_len = 0 + if pose_latents is not None: + scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype) + scail_x = scail_x.flatten(2).transpose(1, 2) + scail_pose_seq_len = scail_x.shape[1] + x = torch.cat([x, scail_x], dim=1) + del scail_x + + # time embeddings + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.cat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + + # head + x = self.head(x, e) + + if scail_pose_seq_len > 0: + x = x[:, :-scail_pose_seq_len] + + # unpatchify + x = self.unpatchify(x, grid_sizes) + + if reference_latent is not None: + x = x[:, :, reference_latent.shape[2]:] + + return x + + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}): + main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) + + if pose_latents is None: + return main_freqs + + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + + F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] + + # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames + h_scale = h / H_pose + w_scale = w / W_pose + + # 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code + h_shift = (h_scale - 1) / 2 + w_shift = (w_scale - 1) / 2 + pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}} + pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options) + + return torch.cat([main_freqs, pose_freqs], dim=1) + + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs): + bs, c, t, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + if pose_latents is not None: + pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size) + + t_len = t + if time_dim_concat is not None: + time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) + x = torch.cat([x, time_dim_concat], dim=2) + t_len = x.shape[2] + + reference_latent = None + if "reference_latent" in kwargs: + reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) + t_len += reference_latent.shape[2] + + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] diff --git a/comfy/model_base.py b/comfy/model_base.py index 85cd30bae..a1c690b9b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1502,6 +1502,44 @@ class WAN21_FlowRVS(WAN21): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) self.image_to_video = image_to_video +class WAN21_SCAIL(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel) + self.memory_usage_factor_conds = ("reference_latent", "pose_latents") + self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + ref_latent = self.process_latent_in(reference_latents[-1]) + ref_mask = torch.ones_like(ref_latent[:, :4]) + ref_latent = torch.cat([ref_latent, ref_mask], dim=1) + out['reference_latent'] = comfy.conds.CONDRegular(ref_latent) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + pose_latents = self.process_latent_in(pose_latents) + pose_mask = torch.ones_like(pose_latents[:, :4]) + pose_latents = torch.cat([pose_latents, pose_mask], dim=1) + out['pose_latents'] = comfy.conds.CONDRegular(pose_latents) + + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]] + + return out + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8a1d8ea4d..3faa950ca 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -498,6 +498,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "humo" elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "animate" + elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "scail" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 473fbbfd4..4f63e8327 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1268,6 +1268,16 @@ class WAN21_FlowRVS(WAN21_T2V): out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device) return out +class WAN21_SCAIL(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "scail", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1710,6 +1720,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index effa994d1..e50bfcd2c 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1456,6 +1456,63 @@ class WanInfiniteTalkToVideo(io.ComfyNode): return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) +class WanSCAILToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSCAILToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("reference_image", optional=True), + io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), + io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), + io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."), + io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + ref_latent = None + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(reference_image[:, :, :, :3]) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength + positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1476,6 +1533,7 @@ class WanExtension(ComfyExtension): WanAnimateToVideo, Wan22ImageToVideoLatent, WanInfiniteTalkToVideo, + WanSCAILToVideo, ] async def comfy_entrypoint() -> WanExtension: diff --git a/node_helpers.py b/node_helpers.py index 4ff960ef8..d3d834516 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -1,5 +1,6 @@ import hashlib import torch +import logging from comfy.cli_args import args @@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False): return c +def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0): + """ + Apply values to conditioning only during [start_percent, end_percent], keeping the + original conditioning active outside that range. Respects existing per-entry ranges. + """ + if start_percent > end_percent: + logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})") + return conditioning + + EPS = 1e-5 # the sampler gates entries with strict > / <, shift boundaries slightly to ensure only one conditioning is active per timestep + c = [] + for t in conditioning: + cond_start = t[1].get("start_percent", 0.0) + cond_end = t[1].get("end_percent", 1.0) + intersect_start = max(start_percent, cond_start) + intersect_end = min(end_percent, cond_end) + + if intersect_start >= intersect_end: # no overlap: emit unchanged + c.append(t) + continue + + if intersect_start > cond_start: # part before the requested range + c.extend(conditioning_set_values([t], {"start_percent": cond_start, "end_percent": intersect_start - EPS})) + + c.extend(conditioning_set_values([t], {**values, "start_percent": intersect_start, "end_percent": intersect_end})) + + if intersect_end < cond_end: # part after the requested range + c.extend(conditioning_set_values([t], {"start_percent": intersect_end + EPS, "end_percent": cond_end})) + return c + def pillow(fn, arg): prev_value = None try: From 5f41584e960d3ad90f6581278e57f7b52e771db4 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 28 Feb 2026 13:50:18 -0800 Subject: [PATCH 02/75] Disable dynamic_vram when weight hooks applied (#12653) * sd: add support for clip model reconstruction * nodes: SetClipHooks: Demote the dynamic model patcher * mp: Make dynamic_disable more robust The backup need to not be cloned. In addition add a delegate object to ModelPatcherDynamic so that non-cloning code can do ModelPatcherDynamic demotion * sampler_helpers: Demote to non-dynamic model patcher when hooking * code rabbit review comments --- comfy/model_patcher.py | 29 ++++++++++++++++++++-------- comfy/sampler_helpers.py | 12 ++++++++++++ comfy/samplers.py | 2 ++ comfy/sd.py | 38 +++++++++++++++++++++++++++---------- comfy_extras/nodes_hooks.py | 2 +- 5 files changed, 64 insertions(+), 19 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1c9ba8096..3fc76d9db 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -308,15 +308,22 @@ class ModelPatcher: def get_free_memory(self, device): return comfy.model_management.get_free_memory(device) - def clone(self, disable_dynamic=False): + def get_clone_model_override(self): + return self.model, (self.backup, self.object_patches_backup, self.pinned) + + def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ - model = self.model if self.is_dynamic() and disable_dynamic: class_ = ModelPatcher - temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) - model = temp_model_patcher.model + if model_override is None: + if self.cached_patcher_init is None: + raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.") + temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) + model_override = temp_model_patcher.get_clone_model_override() + if model_override is None: + model_override = self.get_clone_model_override() - n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) + n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -325,13 +332,12 @@ class ModelPatcher: n.object_patches = self.object_patches.copy() n.weight_wrapper_patches = self.weight_wrapper_patches.copy() n.model_options = comfy.utils.deepcopy_list_dict(self.model_options) - n.backup = self.backup - n.object_patches_backup = self.object_patches_backup n.parent = self - n.pinned = self.pinned n.force_cast_weights = self.force_cast_weights + n.backup, n.object_patches_backup, n.pinned = model_override[1] + # attachments n.attachments = {} for k in self.attachments: @@ -1435,6 +1441,7 @@ class ModelPatcherDynamic(ModelPatcher): del self.model.model_loaded_weight_memory if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} + self.non_dynamic_delegate_model = None assert load_device is not None def is_dynamic(self): @@ -1669,4 +1676,10 @@ class ModelPatcherDynamic(ModelPatcher): def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: pass + def get_non_dynamic_delegate(self): + model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model) + self.non_dynamic_delegate_model = model_patcher.get_clone_model_override() + return model_patcher + + CoreModelPatcher = ModelPatcher diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1f75f2ba7..bbba09e26 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -66,6 +66,18 @@ def convert_cond(cond): out.append(temp) return out +def cond_has_hooks(cond): + for c in cond: + temp = c[1] + if "hooks" in temp: + return True + if "control" in temp: + control = temp["control"] + extra_hooks = control.get_extra_hooks() + if len(extra_hooks) > 0: + return True + return False + def get_additional_models(conds, dtype): """loads additional models in conditioning""" cnets: list[ControlBase] = [] diff --git a/comfy/samplers.py b/comfy/samplers.py index 8b9782956..8be449ef7 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -946,6 +946,8 @@ class CFGGuider: def inner_set_conds(self, conds): for k in conds: + if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]): + self.model_patcher = self.model_patcher.get_non_dynamic_delegate() self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) def __call__(self, *args, **kwargs): diff --git a/comfy/sd.py b/comfy/sd.py index 7713d4678..a9ad7c2d2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -204,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip class CLIP: - def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): + def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False): if no_init: return params = target.params.copy() @@ -233,7 +233,8 @@ class CLIP: model_management.archive_model_dtypes(self.cond_stage_model) self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -267,9 +268,9 @@ class CLIP: logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) self.tokenizer_options = {} - def clone(self): + def clone(self, disable_dynamic=False): n = CLIP(no_init=True) - n.patcher = self.patcher.clone() + n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic) n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx @@ -1164,14 +1165,21 @@ class CLIPType(Enum): LONGCAT_IMAGE = 26 -def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): + +def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): + clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic) + return clip.patcher + +def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): clip_data = [] for p in ckpt_paths: sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) if model_options.get("custom_operations", None) is None: sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) clip_data.append(sd) - return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) + clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic) + clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options)) + return clip class TEModel(Enum): @@ -1276,7 +1284,7 @@ def llama_detect(clip_data): return {} -def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): +def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): clip_data = state_dicts class EmptyClass: @@ -1496,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) - clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options) + clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic) return clip def load_gligen(ckpt_path): @@ -1541,8 +1549,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) - if output_model: + if output_model and out[0] is not None: out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options)) + if output_clip and out[1] is not None: + out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options)) return out def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): @@ -1553,6 +1563,14 @@ def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, disable_dynamic=disable_dynamic) return model +def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + _, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False, + embedding_directory=embedding_directory, output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic) + return clip.patcher + def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False): clip = None clipvision = None @@ -1638,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: parameters = comfy.utils.calculate_parameters(clip_sd) - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic) else: logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index be7d600cd..056369e86 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -248,7 +248,7 @@ class SetClipHooks: def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): if hooks is not None: - clip = clip.clone() + clip = clip.clone(disable_dynamic=True) if apply_to_conds: clip.apply_hooks_to_conds = hooks clip.patcher.forced_hooks = hooks.clone() From 48bb0bd18aa90bba0eac7b4c1a1400c4f7110046 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 28 Feb 2026 13:52:30 -0800 Subject: [PATCH 03/75] cli_args: Default comfy to DynamicVram mode (#12658) --- comfy/cli_args.py | 4 ++-- main.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 63daca861..13079c7bc 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -146,6 +146,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") +parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") @@ -159,7 +160,6 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" - DynamicVRAM = "dynamic_vram" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) @@ -260,4 +260,4 @@ else: args.fast = set(args.fast) def enables_dynamic_vram(): - return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only + return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu diff --git a/main.py b/main.py index 3fe8f0589..a0545d9b3 100644 --- a/main.py +++ b/main.py @@ -192,7 +192,7 @@ import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher -if enables_dynamic_vram(): +if enables_dynamic_vram() and comfy.model_management.is_nvidia(): if comfy.model_management.torch_version_numeric < (2, 8): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): From 17106cb124fcfa0b75ea24993c65aa024059fc8d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 28 Feb 2026 19:21:32 -0800 Subject: [PATCH 04/75] Move parsing of requirements logic to function. (#12701) --- app/frontend_management.py | 42 ++------------------ tests-unit/app_test/frontend_manager_test.py | 6 +++ utils/install_util.py | 33 +++++++++++++++ 3 files changed, 42 insertions(+), 39 deletions(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index bdaa85812..f753ef0de 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -17,7 +17,7 @@ from importlib.metadata import version import requests from typing_extensions import NotRequired -from utils.install_util import get_missing_requirements_message, requirements_path +from utils.install_util import get_missing_requirements_message, get_required_packages_versions from comfy.cli_args import DEFAULT_VERSION_STRING import app.logger @@ -45,25 +45,7 @@ def get_installed_frontend_version(): def get_required_frontend_version(): - """Get the required frontend version from requirements.txt.""" - try: - with open(requirements_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line.startswith("comfyui-frontend-package=="): - version_str = line.split("==")[-1] - if not is_valid_version(version_str): - logging.error(f"Invalid version format in requirements.txt: {version_str}") - return None - return version_str - logging.error("comfyui-frontend-package not found in requirements.txt") - return None - except FileNotFoundError: - logging.error("requirements.txt not found. Cannot determine required frontend version.") - return None - except Exception as e: - logging.error(f"Error reading requirements.txt: {e}") - return None + return get_required_packages_versions().get("comfyui-frontend-package", None) def check_frontend_version(): @@ -217,25 +199,7 @@ class FrontendManager: @classmethod def get_required_templates_version(cls) -> str: - """Get the required workflow templates version from requirements.txt.""" - try: - with open(requirements_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line.startswith("comfyui-workflow-templates=="): - version_str = line.split("==")[-1] - if not is_valid_version(version_str): - logging.error(f"Invalid templates version format in requirements.txt: {version_str}") - return None - return version_str - logging.error("comfyui-workflow-templates not found in requirements.txt") - return None - except FileNotFoundError: - logging.error("requirements.txt not found. Cannot determine required templates version.") - return None - except Exception as e: - logging.error(f"Error reading requirements.txt: {e}") - return None + return get_required_packages_versions().get("comfyui-workflow-templates", None) @classmethod def default_frontend_path(cls) -> str: diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py index 643f04e72..1d5a84b47 100644 --- a/tests-unit/app_test/frontend_manager_test.py +++ b/tests-unit/app_test/frontend_manager_test.py @@ -49,6 +49,12 @@ def mock_provider(mock_releases): return provider +@pytest.fixture(autouse=True) +def clear_cache(): + import utils.install_util + utils.install_util.PACKAGE_VERSIONS = {} + + def test_get_release(mock_provider, mock_releases): version = "1.0.0" release = mock_provider.get_release(version) diff --git a/utils/install_util.py b/utils/install_util.py index 0f59bcf91..34489aec5 100644 --- a/utils/install_util.py +++ b/utils/install_util.py @@ -1,5 +1,7 @@ from pathlib import Path import sys +import logging +import re # The path to the requirements.txt file requirements_path = Path(__file__).parents[1] / "requirements.txt" @@ -16,3 +18,34 @@ Please install the updated requirements.txt file by running: {sys.executable} {extra}-m pip install -r {requirements_path} If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem. """.strip() + + +def is_valid_version(version: str) -> bool: + """Validate if a string is a valid semantic version (X.Y.Z format).""" + pattern = r"^(\d+)\.(\d+)\.(\d+)$" + return bool(re.match(pattern, version)) + + +PACKAGE_VERSIONS = {} +def get_required_packages_versions(): + if len(PACKAGE_VERSIONS) > 0: + return PACKAGE_VERSIONS.copy() + out = PACKAGE_VERSIONS + try: + with open(requirements_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip().replace(">=", "==") + s = line.split("==") + if len(s) == 2: + version_str = s[-1] + if not is_valid_version(version_str): + logging.error(f"Invalid version format in requirements.txt: {version_str}") + continue + out[s[0]] = version_str + return out.copy() + except FileNotFoundError: + logging.error("requirements.txt not found.") + return None + except Exception as e: + logging.error(f"Error reading requirements.txt: {e}") + return None From 1080bd442a7509d29bfe0b29cac9222de406c994 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 28 Feb 2026 19:23:28 -0800 Subject: [PATCH 05/75] Disable dynamic vram on wsl. (#12706) --- comfy/model_management.py | 8 ++++++++ main.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index f73613f17..86f840ada 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -180,6 +180,14 @@ def is_ixuca(): return True return False +def is_wsl(): + version = platform.uname().release + if version.endswith("-Microsoft"): + return True + elif version.endswith("microsoft-standard-WSL2"): + return True + return False + def get_torch_device(): global directml_enabled global cpu_state diff --git a/main.py b/main.py index a0545d9b3..af701f8df 100644 --- a/main.py +++ b/main.py @@ -192,7 +192,7 @@ import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher -if enables_dynamic_vram() and comfy.model_management.is_nvidia(): +if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl(): if comfy.model_management.torch_version_numeric < (2, 8): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): From d159142615e0a1a7ae4eb711a6ae9f66a5f2d76e Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 28 Feb 2026 20:59:24 -0800 Subject: [PATCH 06/75] refactor: rename Mahiro CFG to Similarity-Adaptive Guidance (#12172) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: rename Mahiro CFG to Similarity-Adaptive Guidance Rename the display name to better describe what the node does: adaptively blends guidance based on cosine similarity between positive and negative conditions. Amp-Thread-ID: https://ampcode.com/threads/T-019c0d36-8b43-745f-b7b2-e35b53f17fa1 Co-authored-by: Amp * feat: add search aliases for old mahiro name Amp-Thread-ID: https://ampcode.com/threads/T-019c0d36-8b43-745f-b7b2-e35b53f17fa1 * rename: Similarity-Adaptive Guidance → Positive-Biased Guidance (per reviewer) - display_name changed to 'Positive-Biased Guidance' to avoid SAG acronym collision - search_aliases expanded: mahiro, mahiro cfg, similarity-adaptive guidance, positive-biased cfg - ruff format applied --------- Co-authored-by: Amp Co-authored-by: Jedrzej Kosinski --- comfy_extras/nodes_mahiro.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 6459ca8c1..a25226e6d 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Mahiro", - display_name="Mahiro CFG", + display_name="Positive-Biased Guidance", category="_for_testing", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", inputs=[ @@ -20,27 +20,35 @@ class Mahiro(io.ComfyNode): io.Model.Output(display_name="patched_model"), ], is_experimental=True, + search_aliases=[ + "mahiro", + "mahiro cfg", + "similarity-adaptive guidance", + "positive-biased cfg", + ], ) @classmethod def execute(cls, model) -> io.NodeOutput: m = model.clone() + def mahiro_normd(args): - scale: float = args['cond_scale'] - cond_p: torch.Tensor = args['cond_denoised'] - uncond_p: torch.Tensor = args['uncond_denoised'] - #naive leap + scale: float = args["cond_scale"] + cond_p: torch.Tensor = args["cond_denoised"] + uncond_p: torch.Tensor = args["uncond_denoised"] + # naive leap leap = cond_p * scale - #sim with uncond leap + # sim with uncond leap u_leap = uncond_p * scale cfg = args["denoised"] merge = (leap + cfg) / 2 normu = torch.sqrt(u_leap.abs()) * u_leap.sign() normm = torch.sqrt(merge.abs()) * merge.sign() sim = F.cosine_similarity(normu, normm).mean() - simsc = 2 * (sim+1) - wm = (simsc*cfg + (4-simsc)*leap) / 4 + simsc = 2 * (sim + 1) + wm = (simsc * cfg + (4 - simsc) * leap) / 4 return wm + m.set_model_sampler_post_cfg_function(mahiro_normd) return io.NodeOutput(m) From 850e8b42ff67cec295edb686c4b85dc7811f5e7f Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 28 Feb 2026 21:38:19 -0800 Subject: [PATCH 07/75] feat: add text preview support to jobs API (#12169) * feat: add text preview support to jobs API Amp-Thread-ID: https://ampcode.com/threads/T-019c0be0-9fc6-71ac-853a-7c7cc846b375 Co-authored-by: Amp * test: update tests to expect text as previewable media type Amp-Thread-ID: https://ampcode.com/threads/T-019c0be0-9fc6-71ac-853a-7c7cc846b375 --------- --- comfy_execution/jobs.py | 53 ++++++++++++++++++++++++++++++------ tests/execution/test_jobs.py | 6 ++-- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 370014fb6..fcd7ef735 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -20,7 +20,7 @@ class JobStatus: # Media types that can be previewed in the frontend -PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'}) +PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'}) # 3D file extensions for preview fallback (no dedicated media_type exists) THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'}) @@ -75,6 +75,23 @@ def normalize_outputs(outputs: dict) -> dict: normalized[node_id] = normalized_node return normalized +# Text preview truncation limit (1024 characters) to prevent preview_output bloat +TEXT_PREVIEW_MAX_LENGTH = 1024 + + +def _create_text_preview(value: str) -> dict: + """Create a text preview dict with optional truncation. + + Returns: + dict with 'content' and optionally 'truncated' flag + """ + if len(value) <= TEXT_PREVIEW_MAX_LENGTH: + return {'content': value} + return { + 'content': value[:TEXT_PREVIEW_MAX_LENGTH], + 'truncated': True + } + def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: """Extract create_time and workflow_id from extra_data. @@ -221,23 +238,43 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]: continue for item in items: - normalized = normalize_output_item(item) - if normalized is None: - continue + if not isinstance(item, dict): + # Handle text outputs (non-dict items like strings or tuples) + normalized = normalize_output_item(item) + if normalized is None: + # Not a 3D file string — check for text preview + if media_type == 'text': + count += 1 + if preview_output is None: + if isinstance(item, tuple): + text_value = item[0] if item else '' + else: + text_value = str(item) + text_preview = _create_text_preview(text_value) + enriched = { + **text_preview, + 'nodeId': node_id, + 'mediaType': media_type + } + if fallback_preview is None: + fallback_preview = enriched + continue + # normalize_output_item returned a dict (e.g. 3D file) + item = normalized count += 1 if preview_output is not None: continue - if isinstance(normalized, dict) and is_previewable(media_type, normalized): + if is_previewable(media_type, item): enriched = { - **normalized, + **item, 'nodeId': node_id, } - if 'mediaType' not in normalized: + if 'mediaType' not in item: enriched['mediaType'] = media_type - if normalized.get('type') == 'output': + if item.get('type') == 'output': preview_output = enriched elif fallback_preview is None: fallback_preview = enriched diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 83c36fe48..814af5c13 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -38,13 +38,13 @@ class TestIsPreviewable: """Unit tests for is_previewable()""" def test_previewable_media_types(self): - """Images, video, audio, 3d media types should be previewable.""" - for media_type in ['images', 'video', 'audio', '3d']: + """Images, video, audio, 3d, text media types should be previewable.""" + for media_type in ['images', 'video', 'audio', '3d', 'text']: assert is_previewable(media_type, {}) is True def test_non_previewable_media_types(self): """Other media types should not be previewable.""" - for media_type in ['latents', 'text', 'metadata', 'files']: + for media_type in ['latents', 'metadata', 'files']: assert is_previewable(media_type, {}) is False def test_3d_extensions_previewable(self): From 4d79f4f0280da6c0a0e37123b9c80f24e2403536 Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Sun, 1 Mar 2026 10:38:30 -0700 Subject: [PATCH 08/75] fix: handle substep sigmas in context window set_step (#12719) Multi-step samplers (eg. dpmpp_2s_ancestral) call the model at intermediate sigma values not present in the schedule. This caused set_step to crash with "No sample_sigmas matched current timestep" when context windows were enabled. The fix is to keep self._step from the last exact match when a substep sigma is encountered, since substeps are still logically part of their parent step and should use the same context windows. Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com> --- comfy/context_windows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 2f82d51da..b54f7f39a 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC): mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: - raise Exception("No sample_sigmas matched current timestep; something went wrong.") + return # substep from multi-step sampler: keep self._step from the last full step self._step = int(matches[0].item()) def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: From c0d472e5b9b256d9e802ecac703bb6a8ca5f9eb8 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Mar 2026 11:14:56 -0800 Subject: [PATCH 09/75] comfy-aimdo 0.2.3 (#12720) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 1b2bd0ae6..35fa3f18f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.2 +comfy-aimdo>=0.2.3 requests #non essential dependencies: From 602f6bd82c1f8b31d1b10b5f9ae4aa9637772ad5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 1 Mar 2026 12:28:39 -0800 Subject: [PATCH 10/75] Make --disable-smart-memory disable dynamic vram. (#12722) --- comfy/cli_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 13079c7bc..bfb61c825 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -260,4 +260,4 @@ else: args.fast = set(args.fast) def enables_dynamic_vram(): - return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu + return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu and not args.disable_smart_memory From dfbf99a06172a5c54002d80abf3e74c0d82c10b9 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Mar 2026 19:18:56 -0800 Subject: [PATCH 11/75] model_mangament: make dynamic --disable-smart-memory work (#12724) This was previously considering the pool of dynamic models as one giant entity for the sake of smart memory, but that isnt really the useful or what a user would reasonably expect. Make Dynamic VRAM properly purge its models just like the old --disable-smart-memory but conditioning the dynamic-for-dynamic bypass on smart memory. Re-enable dynamic smart memory. --- comfy/cli_args.py | 2 +- comfy/model_management.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index bfb61c825..13079c7bc 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -260,4 +260,4 @@ else: args.fast = set(args.fast) def enables_dynamic_vram(): - return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu and not args.disable_smart_memory + return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu diff --git a/comfy/model_management.py b/comfy/model_management.py index 86f840ada..c817d43b5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -639,12 +639,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ if not DISABLE_SMART_MEMORY: memory_to_free = memory_required - get_free_memory(device) ram_to_free = ram_required - get_free_ram() - - if current_loaded_models[i].model.is_dynamic() and for_dynamic: - #don't actually unload dynamic models for the sake of other dynamic models - #as that works on-demand. - memory_required -= current_loaded_models[i].model.loaded_size() - memory_to_free = 0 + if current_loaded_models[i].model.is_dynamic() and for_dynamic: + #don't actually unload dynamic models for the sake of other dynamic models + #as that works on-demand. + memory_required -= current_loaded_models[i].model.loaded_size() + memory_to_free = 0 if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) From 7175c11a4ed41278c9cb9e6961b8d8776ef69f00 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:21:41 -0800 Subject: [PATCH 12/75] comfy aimdo 0.2.4 (#12727) Comfy Aimdo 0.2.4 fixes a VRAM buffer alignment issue that happens in someworkflows where action is able to bypass the pytorch allocator and go straight to the cuda hook. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 35fa3f18f..71019c16f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.3 +comfy-aimdo>=0.2.4 requests #non essential dependencies: From afb54219fac341fa8614fdab090fe8096d0aec1e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 2 Mar 2026 09:24:33 +0200 Subject: [PATCH 13/75] feat(api-nodes): allow to use "IMAGE+TEXT" in NanoBanana2 (#12729) --- comfy_api_nodes/nodes_gemini.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 3fe804e0b..d83d2fc15 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -789,8 +789,6 @@ class GeminiImage2(IO.ComfyNode): validate_string(prompt, strip_whitespace=True, min_length=1) if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": model = "gemini-3.1-flash-image-preview" - if response_modalities == "IMAGE+TEXT": - raise ValueError("IMAGE+TEXT is not currently available for the Nano Banana 2 model.") parts: list[GeminiPart] = [GeminiPart(text=prompt)] if images is not None: @@ -895,7 +893,7 @@ class GeminiNanoBanana2(IO.ComfyNode): ), IO.Combo.Input( "response_modalities", - options=["IMAGE"], + options=["IMAGE", "IMAGE+TEXT"], advanced=True, ), IO.Combo.Input( @@ -925,6 +923,7 @@ class GeminiNanoBanana2(IO.ComfyNode): ], outputs=[ IO.Image.Output(), + IO.String.Output(), ], hidden=[ IO.Hidden.auth_token_comfy_org, From f1f8996e1562c3753666d1c568b2ff629edb9e36 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 3 Mar 2026 01:13:42 +0800 Subject: [PATCH 14/75] chore: update workflow templates to v0.9.5 (#12732) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 71019c16f..608b0cfa6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.4 +comfyui-workflow-templates==0.9.5 comfyui-embedded-docs==0.4.3 torch torchsde From 57dd6c1aadf500d90f635a8d3c15418c0d6d6ecd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 2 Mar 2026 15:54:18 -0800 Subject: [PATCH 15/75] Support loading zeta chroma weights properly. (#12734) --- comfy/model_detection.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 3faa950ca..9f4a26e61 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -423,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" return dit_config - if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 + if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 dit_config = {} dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 @@ -533,8 +533,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config - if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1 - + if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1 dit_config = {} dit_config["image_model"] = "hunyuan3d2_1" dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] @@ -1055,6 +1054,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix) + elif 'noise_refiner.0.attention.norm_k.weight' in state_dict: + n_layers = count_blocks(state_dict, 'layers.{}.') + dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0] + sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix) + for k in state_dict: # For zeta chroma + if k not in sd_map: + sd_map[k] = k elif 'x_embedder.weight' in state_dict: #Flux depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') From 9ebee0a2179b361a24c20838c1848d7988320636 Mon Sep 17 00:00:00 2001 From: Lodestone Date: Tue, 3 Mar 2026 07:43:47 +0700 Subject: [PATCH 16/75] Feat: z-image pixel space (model still training atm) (#12709) * draft zeta (z-image pixel space) * revert gitignore * model loaded and able to run however vector direction still wrong tho * flip the vector direction to original again this time * Move wrongly positioned Z image pixel space class * inherit Radiance LatentFormat class * Fix parameters in classes for Zeta x0 dino * remove arbitrary nn.init instances * Remove unused import of lru_cache --------- Co-authored-by: silveroxides --- comfy/latent_formats.py | 7 + comfy/ldm/lumina/model.py | 265 ++++++++++++++++++++++++++++++++++++++ comfy/model_base.py | 5 + comfy/model_detection.py | 23 ++++ comfy/supported_models.py | 16 ++- 5 files changed, 315 insertions(+), 1 deletion(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f59999af6..6a57bca1c 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat): def process_out(self, latent): return latent + + +class ZImagePixelSpace(ChromaRadiance): + """Pixel-space latent format for ZImage DCT variant. + No VAE encoding/decoding — the model operates directly on RGB pixels. + """ + pass diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 77d1abc97..9e432d5c0 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension import comfy.utils +from comfy.ldm.chroma_radiance.layers import NerfEmbedder def invert_slices(slices, length): @@ -858,3 +859,267 @@ class NextDiT(nn.Module): img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] return -img + +############################################################################# +# Pixel Space Decoder Components # +############################################################################# + +def _modulate_shift_scale(x, shift, scale): + return x * (1 + scale) + shift + + +class PixelResBlock(nn.Module): + """ + Residual block with AdaLN modulation, zero-initialised so it starts as + an identity at the beginning of training. + """ + + def __init__(self, channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device) + self.mlp = nn.Sequential( + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1) + h = _modulate_shift_scale(self.in_ln(x), shift, scale) + h = self.mlp(h) + return x + gate * h + + +class DCTFinalLayer(nn.Module): + """Zero-initialised output projection (adopted from DiT).""" + + def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.norm_final(x)) + + +class SimpleMLPAdaLN(nn.Module): + """ + Small MLP decoder head for the pixel-space variant. + + Takes per-patch pixel values and a per-patch conditioning vector from the + transformer backbone and predicts the denoised pixel values. + + x : [B*N, P^2, C] – noisy pixel values per patch position + c : [B*N, dim] – backbone hidden state per patch (conditioning) + → [B*N, P^2, C] + """ + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + z_channels: int, + num_res_blocks: int, + max_freqs: int = 8, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + + # Project backbone hidden state → per-patch conditioning + self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device) + + # Input projection with DCT positional encoding + self.input_embedder = NerfEmbedder( + in_channels=in_channels, + hidden_size_input=model_channels, + max_freqs=max_freqs, + dtype=dtype, + device=device, + operations=operations, + ) + + # Residual blocks + self.res_blocks = nn.ModuleList([ + PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks) + ]) + + # Output projection + self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + # x: [B*N, 1, P^2*C], c: [B*N, dim] + original_dtype = x.dtype + weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype) + x = self.input_embedder(x) # [B*N, 1, model_channels] + y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels] + x = x.to(weight_dtype) + for block in self.res_blocks: + x = block(x, y) + return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C] + + +############################################################################# +# NextDiT – Pixel Space # +############################################################################# + +class NextDiTPixelSpace(NextDiT): + """ + Pixel-space variant of NextDiT. + + Identical transformer backbone to NextDiT, but the output head is replaced + with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values + per patch rather than a single affine projection. + + Key differences vs NextDiT: + • ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead. + • ``_forward`` stores the raw patchified pixel values before the backbone + embedding and feeds them to ``dec_net`` together with the per-patch + backbone hidden states. + • Supports optional x0 prediction via ``use_x0``. + """ + + def __init__( + self, + # decoder-specific + decoder_hidden_size: int = 3840, + decoder_num_res_blocks: int = 4, + decoder_max_freqs: int = 8, + decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels) + use_x0: bool = False, + # all NextDiT args forwarded unchanged + **kwargs, + ): + super().__init__(**kwargs) + + # Remove the latent-space final layer – not used in pixel space + del self.final_layer + + patch_size = kwargs.get("patch_size", 2) + in_channels = kwargs.get("in_channels", 4) + dim = kwargs.get("dim", 4096) + + # decoder_in_channels is the full flattened patch: patch_size^2 * in_channels + dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels + + self.dec_net = SimpleMLPAdaLN( + in_channels=dec_in_ch, + model_channels=decoder_hidden_size, + out_channels=dec_in_ch, + z_channels=dim, + num_res_blocks=decoder_num_res_blocks, + max_freqs=decoder_max_freqs, + dtype=kwargs.get("dtype"), + device=kwargs.get("device"), + operations=kwargs.get("operations"), + ) + + if use_x0: + self.register_buffer("__x0__", torch.tensor([])) + + # ------------------------------------------------------------------ + # Forward — mirrors NextDiT._forward exactly, replacing final_layer + # with the pixel-space dec_net decoder. + # ------------------------------------------------------------------ + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) + + t = 1.0 - timesteps + cap_feats = context + cap_mask = attention_mask + bs, c, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) + adaln_input = t + + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + + # ---- capture raw pixel patches before patchify_and_embed embeds them ---- + pH = pW = self.patch_size + B, C, H, W = x.shape + pixel_patches = ( + x.view(B, C, H // pH, pH, W // pW, pW) + .permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C] + .flatten(3) # [B, Ht, Wt, pH*pW*C] + .flatten(1, 2) # [B, N, pH*pW*C] + ) + N = pixel_patches.shape[1] + # decoder sees one token per patch: [B*N, 1, P^2*C] + pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C) + + patches = transformer_options.get("patches", {}) + x_is_tensor = isinstance(x, torch.Tensor) + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed( + x, cap_feats, cap_mask, adaln_input, num_tokens, + ref_latents=ref_latents, ref_contexts=ref_contexts, + siglip_feats=siglip_feats, transformer_options=transformer_options + ) + freqs_cis = freqs_cis.to(img.device) + + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" + img_input = img + for i, layer in enumerate(self.layers): + transformer_options["block_index"] = i + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] + + # ---- pixel-space decoder (replaces final_layer + unpatchify) ---- + # img may have padding tokens beyond N; only the first N are real image patches + img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim] + decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim] + + output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C] + output = output.reshape(B, N, -1) # [B, N, P^2*C] + + # prepend zero cap placeholder so unpatchify indexing works unchanged + cap_placeholder = torch.zeros( + B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype + ) + img_out = self.unpatchify( + torch.cat([cap_placeholder, output], dim=1), + img_size, cap_size, return_tensor=x_is_tensor + )[:, :, :h, :w] + + return -img_out + + def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + # _forward returns neg_x0 = -x0 (negated decoder output). + # + # Reference inference (working_inference_reference.py): + # out = _forward(img, t) # = -x0 + # pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual] + # img += (t_prev - t_curr) * pred # Euler step + # + # ComfyUI's Euler sampler does the same: + # x_next = x + (sigma_next - sigma) * model_output + # So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t + neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) + + return (x - neg_x0) / timesteps.view(-1, 1, 1, 1) diff --git a/comfy/model_base.py b/comfy/model_base.py index a1c690b9b..1e01e9edc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1263,6 +1263,11 @@ class Lumina2(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class ZImagePixelSpace(Lumina2): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) + self.memory_usage_factor_conds = ("ref_latents",) + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 9f4a26e61..6eace4628 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -464,6 +464,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if sig_weight is not None: dit_config["siglip_feat_dim"] = sig_weight.shape[0] + dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix) + if dec_cond_key in state_dict_keys: # pixel-space variant + dit_config["image_model"] = "zimage_pixel" + # patch_size and in_channels are derived from x_embedder: + # x_embedder: Linear(patch_size * patch_size * in_channels, dim) + # The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim. + x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] + dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0] + # patch_size: infer from decoder final layer output matching x_embedder input + # in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2) + embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)] + dec_in_ch = dec_out # decoder in == decoder out (same pixel space) + dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3) + dit_config["in_channels"] = 3 + dit_config["decoder_in_channels"] = dec_in_ch + dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0] + dit_config["decoder_num_res_blocks"] = count_blocks( + state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.' + ) + dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5) + if '{}__x0__'.format(key_prefix) in state_dict_keys: + dit_config["use_x0"] = True + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 4f63e8327..c0d3f387f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1118,6 +1118,20 @@ class ZImage(Lumina2): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect)) +class ZImagePixelSpace(ZImage): + unet_config = { + "image_model": "zimage_pixel", + } + + # Pixel-space model: no spatial compression, operates on raw RGB patches. + latent_format = latent_formats.ZImagePixelSpace + + # Much lower memory than latent-space models (no VAE, small patches). + memory_usage_factor = 0.05 # TODO: figure out the optimal value for this. + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ZImagePixelSpace(self, device=device) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -1720,6 +1734,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] From dff0a4a15887383c90a031e3fd48ebc41f6928e7 Mon Sep 17 00:00:00 2001 From: xeinherjer <112741359+xeinherjer-dev@users.noreply.github.com> Date: Tue, 3 Mar 2026 10:17:51 +0900 Subject: [PATCH 17/75] Fix VAEDecodeAudioTiled ignoring tile_size input (#12735) (#12738) --- comfy_extras/nodes_audio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 43df0512f..5d8d9bf6f 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -96,7 +96,7 @@ class VAEEncodeAudio(IO.ComfyNode): def vae_decode_audio(vae, samples, tile=None, overlap=None): if tile is not None: - audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1) + audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1) else: audio = vae.decode(samples["samples"]).movedim(-1, 1) From 09bcbddfcf804634f008f53c1827b7ba9a3956ec Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Mar 2026 08:50:33 -0800 Subject: [PATCH 18/75] ModelPatcherDynamic: Force load all non-comfy weights (#12739) * model_management: Remove non-comfy dynamic _v caster * Force pre-load non-comfy weights to GPU in ModelPatcherDynamic Non-comfy weights may expect to be pre-cast to the target device without in-model casting. Previously they were allocated in the vbar with _v which required the _v fault path in cast_to. Instead, back up the original CPU weight and move it directly to GPU at load time. --- comfy/model_management.py | 40 ------------------------------- comfy/model_patcher.py | 50 ++++++++++++++++----------------------- 2 files changed, 21 insertions(+), 69 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index c817d43b5..0e0e96672 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -32,9 +32,6 @@ import comfy.memory_management import comfy.utils import comfy.quant_ops -import comfy_aimdo.torch -import comfy_aimdo.model_vbar - class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -1206,43 +1203,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): - if hasattr(weight, "_v"): - #Unexpected usage patterns. There is no reason these don't work but they - #have no testing and no callers do this. - assert r is None - assert stream is None - - cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ]) - - if dtype is None: - dtype = weight._model_dtype - - signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) - if signature is not None: - if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): - v_tensor = weight._v_tensor - else: - raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) - v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] - weight._v_tensor = v_tensor - weight._v_signature = signature - #Send it over - v_tensor.copy_(weight, non_blocking=non_blocking) - return v_tensor.to(dtype=dtype) - - r = torch.empty_like(weight, dtype=dtype, device=device) - - if weight.dtype != r.dtype and weight.dtype != weight._model_dtype: - #Offloaded casting could skip this, however it would make the quantizations - #inconsistent between loaded and offloaded weights. So force the double casting - #that would happen in regular flow to make offload deterministic. - cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device) - cast_buffer.copy_(weight, non_blocking=non_blocking) - weight = cast_buffer - r.copy_(weight, non_blocking=non_blocking) - - return r - if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3fc76d9db..e380e406b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1435,10 +1435,6 @@ class ModelPatcherDynamic(ModelPatcher): def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): super().__init__(model, load_device, offload_device, size, weight_inplace_update) - #this is now way more dynamic and we dont support the same base model for both Dynamic - #and non-dynamic patchers. - if hasattr(self.model, "model_loaded_weight_memory"): - del self.model.model_loaded_weight_memory if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} self.non_dynamic_delegate_model = None @@ -1461,9 +1457,7 @@ class ModelPatcherDynamic(ModelPatcher): def loaded_size(self): vbar = self._vbar_get() - if vbar is None: - return 0 - return vbar.loaded_size() + return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory def get_free_memory(self, device): #NOTE: on high condition / batch counts, estimate should have already vacated @@ -1504,6 +1498,7 @@ class ModelPatcherDynamic(ModelPatcher): num_patches = 0 allocated_size = 0 + self.model.model_loaded_weight_memory = 0 with self.use_ejected(): self.unpatch_hooks() @@ -1512,10 +1507,6 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None: vbar.prioritize() - #We force reserve VRAM for the non comfy-weight so we dont have to deal - #with pin and unpin syncrhonization which can be expensive for small weights - #with a high layer rate (e.g. autoregressive LLMs). - #prioritize the non-comfy weights (note the order reverse). loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) loading.sort(reverse=True) @@ -1558,6 +1549,9 @@ class ModelPatcherDynamic(ModelPatcher): if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) self.patch_weight_to_device(key, device_to=device_to) + weight, _, _ = get_key_weight(self.model, key) + if weight is not None: + self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True @@ -1583,21 +1577,15 @@ class ModelPatcherDynamic(ModelPatcher): for param in params: key = key_param_name_to_key(n, param) weight, _, _ = get_key_weight(self.model, key) - weight.seed_key = key - set_dirty(weight, dirty) - geometry = weight - model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype - geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) - weight_size = geometry.numel() * geometry.element_size() - if vbar is not None and not hasattr(weight, "_v"): - weight._v = vbar.alloc(weight_size) - weight._model_dtype = model_dtype - allocated_size += weight_size - vbar.set_watermark_limit(allocated_size) + if key not in self.backup: + self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) + comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to)) + self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() move_weight_functions(m, device_to) - logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") self.model.device = device_to self.model.current_weight_patches_uuid = self.patches_uuid @@ -1613,7 +1601,16 @@ class ModelPatcherDynamic(ModelPatcher): assert self.load_device != torch.device("cpu") vbar = self._vbar_get() - return 0 if vbar is None else vbar.free_memory(memory_to_free) + freed = 0 if vbar is None else vbar.free_memory(memory_to_free) + + if freed < memory_to_free: + for key in list(self.backup.keys()): + bk = self.backup.pop(key) + comfy.utils.set_attr_param(self.model, key, bk.weight) + freed += self.model.model_loaded_weight_memory + self.model.model_loaded_weight_memory = 0 + + return freed def partially_unload_ram(self, ram_to_unload): loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) @@ -1640,11 +1637,6 @@ class ModelPatcherDynamic(ModelPatcher): for m in self.model.modules(): move_weight_functions(m, device_to) - keys = list(self.backup.keys()) - for k in keys: - bk = self.backup[k] - comfy.utils.set_attr_param(self.model, k, bk.weight) - def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): assert not force_patch_weights #See above with self.use_ejected(skip_and_inject_on_exit_only=True): From 174fd6759deee5ea73e4cde4ba2936e8d62d8d66 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Mar 2026 08:51:15 -0800 Subject: [PATCH 19/75] main: Load aimdo after logger is setup (#12743) This was too early. Aimdo can use the logger in error paths and this causes a rogue default init if aimdo has something to log. --- main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index af701f8df..0f58d57b8 100644 --- a/main.py +++ b/main.py @@ -16,11 +16,6 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags -import comfy_aimdo.control - -if enables_dynamic_vram(): - comfy_aimdo.control.init() - if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' @@ -28,6 +23,11 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +import comfy_aimdo.control + +if enables_dynamic_vram(): + comfy_aimdo.control.init() + if os.name == "nt": os.environ['MIMALLOC_PURGE_DELAY'] = '0' From f719a9d928049f85b07b8ecc2259fba4832d37bb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Mar 2026 14:35:22 -0800 Subject: [PATCH 20/75] Adjust memory usage factor of zeta model. (#12746) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c0d3f387f..07feb31b3 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1127,7 +1127,7 @@ class ZImagePixelSpace(ZImage): latent_format = latent_formats.ZImagePixelSpace # Much lower memory than latent-space models (no VAE, small patches). - memory_usage_factor = 0.05 # TODO: figure out the optimal value for this. + memory_usage_factor = 0.03 # TODO: figure out the optimal value for this. def get_model(self, state_dict, prefix="", device=None): return model_base.ZImagePixelSpace(self, device=device) From b6ddc590ed8dafd50df8aad1e626b78276a690c0 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Tue, 3 Mar 2026 19:58:53 -0500 Subject: [PATCH 21/75] CURVE type (#12581) * CURVE type * fix: update typed wrapper unwrap keys to __type__ and __value__ * code improve * code improve --- comfy_api/latest/_io.py | 14 ++++++++++++++ execution.py | 14 ++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 189d7d9bc..050031dc0 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1240,6 +1240,19 @@ class BoundingBox(ComfyTypeIO): return d +@comfytype(io_type="CURVE") +class Curve(ComfyTypeIO): + CurvePoint = tuple[float, float] + Type = list[CurvePoint] + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = [(0.0, 0.0), (1.0, 1.0)] + + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -2226,5 +2239,6 @@ __all__ = [ "PriceBadgeDepends", "PriceBadge", "BoundingBox", + "Curve", "NodeReplace", ] diff --git a/execution.py b/execution.py index 75b021892..7ccdbf93e 100644 --- a/execution.py +++ b/execution.py @@ -876,12 +876,14 @@ async def validate_inputs(prompt_id, prompt, item, validated): continue else: try: - # Unwraps values wrapped in __value__ key. This is used to pass - # list widget value to execution, as by default list value is - # reserved to represent the connection between nodes. - if isinstance(val, dict) and "__value__" in val: - val = val["__value__"] - inputs[x] = val + # Unwraps values wrapped in __value__ key or typed wrapper. + # This is used to pass list widget values to execution, + # as by default list value is reserved to represent the + # connection between nodes. + if isinstance(val, dict): + if "__value__" in val: + val = val["__value__"] + inputs[x] = val if input_type == "INT": val = int(val) From ac6513e142f881202c40eacc5e337982b777ccd0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Mar 2026 18:19:40 -0800 Subject: [PATCH 22/75] DynamicVram: Add casting / fix torch Buffer weights (#12749) * respect model dtype in non-comfy caster * utils: factor out parent and name functionality of set_attr * utils: implement set_attr_buffer for torch buffers * ModelPatcherDynamic: Implement torch Buffer loading If there is a buffer in dynamic - force load it. --- comfy/model_management.py | 2 ++ comfy/model_patcher.py | 22 ++++++++++++++++++---- comfy/utils.py | 19 +++++++++++++++---- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0e0e96672..0f5966371 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -796,6 +796,8 @@ def archive_model_dtypes(model): for name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + for buf_name, buf in module.named_buffers(recurse=False): + setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype) def cleanup_models(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e380e406b..70f78a089 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -241,6 +241,7 @@ class ModelPatcher: self.patches = {} self.backup = {} + self.backup_buffers = {} self.object_patches = {} self.object_patches_backup = {} self.weight_wrapper_patches = {} @@ -309,7 +310,7 @@ class ModelPatcher: return comfy.model_management.get_free_memory(device) def get_clone_model_override(self): - return self.model, (self.backup, self.object_patches_backup, self.pinned) + return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ @@ -336,7 +337,7 @@ class ModelPatcher: n.force_cast_weights = self.force_cast_weights - n.backup, n.object_patches_backup, n.pinned = model_override[1] + n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1] # attachments n.attachments = {} @@ -1579,11 +1580,22 @@ class ModelPatcherDynamic(ModelPatcher): weight, _, _ = get_key_weight(self.model, key) if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) - comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to)) - self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() + model_dtype = getattr(m, param + "_comfy_model_dtype", None) + casted_weight = weight.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_param(self.model, key, casted_weight) + self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size() move_weight_functions(m, device_to) + for key, buf in self.model.named_buffers(recurse=True): + if key not in self.backup_buffers: + self.backup_buffers[key] = buf + module, buf_name = comfy.utils.resolve_attr(self.model, key) + model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None) + casted_buf = buf.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_buffer(self.model, key, casted_buf) + self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size() + force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") @@ -1607,6 +1619,8 @@ class ModelPatcherDynamic(ModelPatcher): for key in list(self.backup.keys()): bk = self.backup.pop(key) comfy.utils.set_attr_param(self.model, key, bk.weight) + for key in list(self.backup_buffers.keys()): + comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key)) freed += self.model.model_loaded_weight_memory self.model.model_loaded_weight_memory = 0 diff --git a/comfy/utils.py b/comfy/utils.py index 0769cef44..6e1d14419 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): ATTR_UNSET={} -def set_attr(obj, attr, value): +def resolve_attr(obj, attr): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1], ATTR_UNSET) + return obj, attrs[-1] + +def set_attr(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) if value is ATTR_UNSET: - delattr(obj, attrs[-1]) + delattr(obj, name) else: - setattr(obj, attrs[-1], value) + setattr(obj, name, value) return prev def set_attr_param(obj, attr, value): return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) +def set_attr_buffer(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) + persistent = name not in getattr(obj, "_non_persistent_buffers_set", set()) + obj.register_buffer(name, value, persistent=persistent) + return prev + def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".") From eb011733b6e4d8a9f7b67a1787d817bfc8c0a5b4 Mon Sep 17 00:00:00 2001 From: Arthur R Longbottom Date: Tue, 3 Mar 2026 21:29:00 -0800 Subject: [PATCH 23/75] Fix VideoFromComponents.save_to crash when writing to BytesIO (#12683) * Fix VideoFromComponents.save_to crash when writing to BytesIO When `get_container_format()` or `get_stream_source()` is called on a tensor-based video (VideoFromComponents), it calls `save_to(BytesIO())`. Since BytesIO has no file extension, `av.open` can't infer the output format and throws `ValueError: Could not determine output format`. The sibling class `VideoFromFile` already handles this correctly via `get_open_write_kwargs()`, which detects BytesIO and sets the format explicitly. `VideoFromComponents` just never got the same treatment. This surfaces when any downstream node validates the container format of a tensor-based video, like TopazVideoEnhance or any node that calls `validate_container_format_is_mp4()`. Three-line fix in `comfy_api/latest/_input_impl/video_types.py`. * Add docstring to save_to to satisfy CI coverage check --- comfy_api/latest/_input_impl/video_types.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index a3d48c87f..58a37c9e8 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput): codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, ): + """Save the video to a file path or BytesIO buffer.""" if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: @@ -408,6 +409,10 @@ class VideoFromComponents(VideoInput): extra_kwargs = {} if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: extra_kwargs["format"] = format.value + elif isinstance(path, io.BytesIO): + # BytesIO has no file extension, so av.open can't infer the format. + # Default to mp4 since that's the only supported format anyway. + extra_kwargs["format"] = "mp4" with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: # Add metadata before writing any streams if metadata is not None: From d531e3fb2a885d675d5b6d3a496b4af5d9757af1 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 07:47:44 -0800 Subject: [PATCH 24/75] model_patcher: Improve dynamic offload heuristic (#12759) Define a threshold below which a weight loading takes priority. This actually makes the offload consistent with non-dynamic, because what happens, is when non-dynamic fills ints to_load list, it will fill-up any left-over pieces that could fix large weights with small weights and load them, even though they were lower priority. This actually improves performance because the timy weights dont cost any VRAM and arent worth the control overhead of the DMA etc. --- comfy/model_patcher.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 70f78a089..168ce8430 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -699,7 +699,7 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self, prio_comfy_cast_weights=False, default_device=None): + def _load_list(self, for_dynamic=False, default_device=None): loading = [] for n, m in self.model.named_modules(): default = False @@ -727,8 +727,13 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () - loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) + # Dynamic: small weights (<64KB) first, then larger weights prioritized by size. + # Non-dynamic: prioritize by module offload cost. + if for_dynamic: + sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem) + else: + sort_criteria = (module_offload_mem,) + loading.append(sort_criteria + (module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -1508,11 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None: vbar.prioritize() - loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) - loading.sort(reverse=True) + loading = self._load_list(for_dynamic=True, default_device=device_to) + loading.sort() for x in loading: - _, _, _, n, m, params = x + *_, module_mem, n, m, params = x def set_dirty(item, dirty): if dirty or not hasattr(item, "_v_signature"): @@ -1627,9 +1632,9 @@ class ModelPatcherDynamic(ModelPatcher): return freed def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) + loading = self._load_list(for_dynamic=True, default_device=self.offload_device) for x in loading: - _, _, _, _, m, _ = x + *_, m, _ = x ram_to_unload -= comfy.pinned_memory.unpin_memory(m) if ram_to_unload <= 0: return From 9b85cf955858b0aca6b7b30c30b404470ea0c964 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 07:49:13 -0800 Subject: [PATCH 25/75] Comfy Aimdo 0.2.5 + Fix offload performance in DynamicVram (#12754) * ops: dont unpin nothing This was calling into aimdo in the none case (offloaded weight). Whats worse, is aimdo syncs for unpinning an offloaded weight, as that is the corner case of a weight getting evicted by its own use which does require a sync. But this was heppening every offloaded weight causing slowdown. * mp: fix get_free_memory policy The ModelPatcherDynamic get_free_memory was deducting the model from to try and estimate the conceptual free memory with doing any offloading. This is kind of what the old memory_memory_required was estimating in ModelPatcher load logic, however in practical reality, between over-estimates and padding, the loader usually underloaded models enough such that sampling could send CFG +/- through together even when partially loaded. So don't regress from the status quo and instead go all in on the idea that offloading is less of an issue than debatching. Tell the sampler it can use everything. --- comfy/model_patcher.py | 14 +++++++------- comfy/ops.py | 4 ++-- requirements.txt | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 168ce8430..7e5ad7aa4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -307,7 +307,13 @@ class ModelPatcher: return self.model.lowvram_patch_counter def get_free_memory(self, device): - return comfy.model_management.get_free_memory(device) + #Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In + #the vast majority of setups a little bit of offloading on the giant model more + #than pays for CFG. So return everything both torch and Aimdo could give us + aimdo_mem = 0 + if comfy.memory_management.aimdo_enabled: + aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() + return comfy.model_management.get_free_memory(device) + aimdo_mem def get_clone_model_override(self): return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) @@ -1465,12 +1471,6 @@ class ModelPatcherDynamic(ModelPatcher): vbar = self._vbar_get() return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory - def get_free_memory(self, device): - #NOTE: on high condition / batch counts, estimate should have already vacated - #all non-dynamic models so this is safe even if its not 100% true that this - #would all be avaiable for inference use. - return comfy.model_management.get_total_memory(device) - self.model_size() - #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. def pin_weight_to_device(self, key): diff --git a/comfy/ops.py b/comfy/ops.py index 6ee6075fb..8275dd0a5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -269,8 +269,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream): return os, weight_a, bias_a = offload_stream device=None - #FIXME: This is not good RTTI - if not isinstance(weight_a, torch.Tensor): + #FIXME: This is really bad RTTI + if weight_a is not None and not isinstance(weight_a, torch.Tensor): comfy_aimdo.model_vbar.vbar_unpin(s._v) device = weight_a if os is None: diff --git a/requirements.txt b/requirements.txt index 608b0cfa6..110568cd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.4 +comfy-aimdo>=0.2.5 requests #non essential dependencies: From 0a7446ade4bbeecfaf36e9a70eeabbeb0f6e59ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:59:56 +0200 Subject: [PATCH 26/75] Pass tokens when loading text gen model for text generation (#12755) Co-authored-by: Jedrzej Kosinski --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index a9ad7c2d2..8bcd09582 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -428,7 +428,7 @@ class CLIP: def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): self.cond_stage_model.reset_clip_options() - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"layer": None}) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) From 8811db52db5d0aea49c1dbedd733a6b9304b83a9 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 12:12:37 -0800 Subject: [PATCH 27/75] comfy-aimdo 0.2.6 (#12764) Comfy Aimdo 0.2.6 fixes a GPU virtual address leak. This would manfiest as an error after a number of workflow runs. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 110568cd3..dae46d873 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.5 +comfy-aimdo>=0.2.6 requests #non essential dependencies: From ac4a943ff364885166def5d418582db971554caf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:33:14 -0800 Subject: [PATCH 28/75] Initial load device should be cpu when using dynamic vram. (#12766) --- comfy/model_management.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0f5966371..809600815 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -830,11 +830,14 @@ def unet_offload_device(): return torch.device("cpu") def unet_inital_load_device(parameters, dtype): + cpu_dev = torch.device("cpu") + if comfy.memory_management.aimdo_enabled: + return cpu_dev + torch_dev = get_torch_device() if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: return torch_dev - cpu_dev = torch.device("cpu") if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM: return cpu_dev @@ -842,7 +845,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled: + if mem_dev > mem_cpu and model_size < mem_dev: return torch_dev else: return cpu_dev @@ -945,6 +948,9 @@ def text_encoder_device(): return torch.device("cpu") def text_encoder_initial_device(load_device, offload_device, model_size=0): + if comfy.memory_management.aimdo_enabled: + return offload_device + if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device From 43c64b6308f93c331f057e12799bad0a68be5117 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 17:06:20 -0800 Subject: [PATCH 29/75] Support the LTXAV 2.3 model. (#12773) --- comfy/ldm/lightricks/av_model.py | 185 ++++++- comfy/ldm/lightricks/embeddings_connector.py | 4 + comfy/ldm/lightricks/model.py | 186 ++++++- comfy/ldm/lightricks/vae/audio_vae.py | 7 +- .../vae/causal_audio_autoencoder.py | 67 +-- .../vae/causal_video_autoencoder.py | 48 +- comfy/ldm/lightricks/vocoders/vocoder.py | 523 +++++++++++++++++- comfy/model_base.py | 2 +- comfy/sd.py | 2 +- comfy/text_encoders/lt.py | 68 ++- 10 files changed, 959 insertions(+), 133 deletions(-) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 553fd5b38..08d686b7b 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -2,11 +2,16 @@ from typing import Tuple import torch import torch.nn as nn from comfy.ldm.lightricks.model import ( + ADALN_BASE_PARAMS_COUNT, + ADALN_CROSS_ATTN_PARAMS_COUNT, CrossAttention, FeedForward, AdaLayerNormSingle, PixArtAlphaTextProjection, + NormSingleLinearTextProjection, LTXVModel, + apply_cross_attention_adaln, + compute_prompt_timestep, ) from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module): v_context_dim=None, a_context_dim=None, attn_precision=None, + apply_gated_attention=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=v_dim, @@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=vd_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=ad_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module): heads=v_heads, dim_head=vd_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module): a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations ) - self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype)) self.audio_scale_shift_table = nn.Parameter( - torch.empty(6, a_dim, device=device, dtype=dtype) + torch.empty(num_ada_params, a_dim, device=device, dtype=dtype) ) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype)) + self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype)) + self.scale_shift_table_a2v_ca_audio = nn.Parameter( torch.empty(5, a_dim, device=device, dtype=dtype) ) @@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module): return (*scale_shift_ada_values, *gate_ada_values) + def _apply_text_cross_attention( + self, x, context, attn, scale_shift_table, prompt_scale_shift_table, + timestep, prompt_timestep, attention_mask, transformer_options, + ): + """Apply text cross-attention, with optional ADaLN modulation.""" + if self.cross_attention_adaln: + shift_q, scale_q, gate = self.get_ada_values( + scale_shift_table, x.shape[0], timestep, slice(6, 9) + ) + return apply_cross_attention_adaln( + x, context, attn, shift_q, scale_q, gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask, transformer_options, + ) + return attn( + comfy.ldm.common_dit.rms_norm(x), context=context, + mask=attention_mask, transformer_options=transformer_options, + ) + def forward( self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None, + v_prompt_timestep=None, a_prompt_timestep=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module): vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] vx.addcmul_(attn1_out, vgate_msa) del vgate_msa, attn1_out - vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options)) + vx.add_(self._apply_text_cross_attention( + vx, v_context, self.attn2, self.scale_shift_table, + getattr(self, 'prompt_scale_shift_table', None), + v_timestep, v_prompt_timestep, attention_mask, transformer_options,) + ) # audio if run_ax: @@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module): agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0] ax.addcmul_(attn1_out, agate_msa) del agate_msa, attn1_out - ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options)) + ax.add_(self._apply_text_cross_attention( + ax, a_context, self.audio_attn2, self.audio_scale_shift_table, + getattr(self, 'audio_prompt_scale_shift_table', None), + a_timestep, a_prompt_timestep, attention_mask, transformer_options,) + ) # video - audio cross attention. if run_a2v or run_v2a: @@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel): use_middle_indices_grid=False, timestep_scale_multiplier=1000.0, av_ca_timestep_scale_multiplier=1.0, + apply_gated_attention=False, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel): self.audio_attention_head_dim = audio_attention_head_dim self.audio_num_attention_heads = audio_num_attention_heads self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.apply_gated_attention = apply_gated_attention # Calculate audio dimensions self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim @@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel): ) # Audio-specific AdaLN + audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.audio_adaln_single = AdaLayerNormSingle( self.audio_inner_dim, + embedding_coefficient=audio_embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations, ) + if self.cross_attention_adaln: + self.audio_prompt_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=2, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_prompt_adaln_single = None + num_scale_shift_values = 4 self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( self.inner_dim, @@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel): ) # Audio caption projection - self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.audio_inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.audio_caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_caption_projection = lambda a: a + else: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + connector_split_rope = kwargs.get("rope_type", "split") == "split" + connector_gated_attention = kwargs.get("connector_apply_gated_attention", False) + attention_head_dim = kwargs.get("connector_attention_head_dim", 128) + num_attention_heads = kwargs.get("connector_num_attention_heads", 30) + num_layers = kwargs.get("connector_num_layers", 2) self.audio_embeddings_connector = Embeddings1DConnector( - split_rope=True, + attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim), + num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads), + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) self.video_embeddings_connector = Embeddings1DConnector( - split_rope=True, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) - def preprocess_text_embeds(self, context): - if context.shape[-1] == self.caption_channels * 2: - return context - out_vid = self.video_embeddings_connector(context)[0] - out_audio = self.audio_embeddings_connector(context)[0] + def preprocess_text_embeds(self, context, unprocessed=False): + # LTXv2 fully processed context has dimension of self.caption_channels * 2 + # LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim + if not unprocessed: + if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2): + return context + if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim: + context_vid = context[:, :, :self.cross_attention_dim] + context_audio = context[:, :, self.cross_attention_dim:] + else: + context_vid = context + context_audio = context + if self.caption_proj_before_connector: + context_vid = self.caption_projection(context_vid) + context_audio = self.audio_caption_projection(context_audio) + out_vid = self.video_embeddings_connector(context_vid)[0] + out_audio = self.audio_embeddings_connector(context_audio)[0] return torch.concat((out_vid, out_audio), dim=-1) def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel): ad_head=self.audio_attention_head_dim, v_context_dim=self.cross_attention_dim, a_context_dim=self.audio_cross_attention_dim, + apply_gated_attention=self.apply_gated_attention, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel): v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) + v_prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + # Prepare audio timestep a_timestep = kwargs.get("a_timestep") if a_timestep is not None: @@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel): # Cross-attention timesteps - compress these too av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( - a_timestep_flat, + timestep.max().expand_as(a_timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( - timestep_flat, + a_timestep.max().expand_as(timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( - timestep_flat * av_ca_factor, + a_timestep.max().expand_as(timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( - a_timestep_flat * av_ca_factor, + timestep.max().expand_as(a_timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel): # Audio timesteps a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1]) + + a_prompt_timestep = compute_prompt_timestep( + self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype + ) else: a_timestep = timestep_scaled a_embedded_timestep = kwargs.get("embedded_timestep") cross_av_timestep_ss = [] + a_prompt_timestep = None - return [v_timestep, a_timestep, cross_av_timestep_ss], [ + return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [ v_embedded_timestep, a_embedded_timestep, - ] + ], None def _prepare_context(self, context, batch_size, x, attention_mask=None): vx = x[0] ax = x[1] + video_dim = vx.shape[-1] + audio_dim = ax.shape[-1] + + v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim + a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim + v_context, a_context = torch.split( - context, int(context.shape[-1] / 2), len(context.shape) - 1 + context, [v_context_dim, a_context_dim], len(context.shape) - 1 ) v_context, attention_mask = super()._prepare_context( v_context, batch_size, vx, attention_mask ) - if self.audio_caption_projection is not None: + if self.caption_proj_before_connector is False: a_context = self.audio_caption_projection(a_context) - a_context = a_context.view(batch_size, -1, ax.shape[-1]) + a_context = a_context.view(batch_size, -1, audio_dim) return [v_context, a_context], attention_mask @@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel): av_ca_v2a_gate_noise_timestep, ) = timestep[2] + v_prompt_timestep = timestep[3] + a_prompt_timestep = timestep[4] + """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=args["a_cross_gate_timestep"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), + v_prompt_timestep=args.get("v_prompt_timestep"), + a_prompt_timestep=args.get("a_prompt_timestep"), ) return out @@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel): "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, + "v_prompt_timestep": v_prompt_timestep, + "a_prompt_timestep": a_prompt_timestep, }, {"original_block": block_wrap}, ) @@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + v_prompt_timestep=v_prompt_timestep, + a_prompt_timestep=a_prompt_timestep, ) return [vx, ax] diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index 33adb9671..2811080be 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module): d_head, context_dim=None, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module): heads=n_heads, dim_head=d_head, context_dim=None, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module): positional_embedding_max_pos=[4096], causal_temporal_positioning=False, num_learnable_registers: Optional[int] = 128, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module): num_attention_heads, attention_head_dim, context_dim=cross_attention_dim, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 60d760d29..bfbc08357 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module): return hidden_states +class NormSingleLinearTextProjection(nn.Module): + """Text projection for 20B models - single linear with RMSNorm (no activation).""" + + def __init__( + self, in_features, hidden_size, dtype=None, device=None, operations=None + ): + super().__init__() + if operations is None: + operations = comfy.ops.disable_weight_init + self.in_norm = operations.RMSNorm( + in_features, eps=1e-6, elementwise_affine=False + ) + self.linear_1 = operations.Linear( + in_features, hidden_size, bias=True, dtype=dtype, device=device + ) + self.hidden_size = hidden_size + self.in_features = in_features + + def forward(self, caption): + caption = self.in_norm(caption) + caption = caption * (self.hidden_size / self.in_features) ** 0.5 + return self.linear_1(caption) + + class GELU_approx(nn.Module): def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None): super().__init__() @@ -343,6 +367,7 @@ class CrossAttention(nn.Module): dim_head=64, dropout=0.0, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -362,6 +387,12 @@ class CrossAttention(nn.Module): self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) + # Optional per-head gating + if apply_gated_attention: + self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device) + else: + self.to_gate_logits = None + self.to_out = nn.Sequential( operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -383,16 +414,30 @@ class CrossAttention(nn.Module): out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) + + # Apply per-head gating if enabled + if self.to_gate_logits is not None: + gate_logits = self.to_gate_logits(x) # (B, T, H) + b, t, _ = out.shape + out = out.view(b, t, self.heads, self.dim_head) + gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity + out = out * gates.unsqueeze(-1) + out = out.view(b, t, self.heads * self.dim_head) + return self.to_out(out) +# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate) +ADALN_BASE_PARAMS_COUNT = 6 +ADALN_CROSS_ATTN_PARAMS_COUNT = 9 class BasicTransformerBlock(nn.Module): def __init__( - self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None ): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, @@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module): operations=operations, ) - self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype)) - attn1_input = comfy.ldm.common_dit.rms_norm(x) - attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) - attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) - x.addcmul_(attn1_input, gate_msa) - del attn1_input + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2) - x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) + x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa + + if self.cross_attention_adaln: + shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2) + x += apply_cross_attention_adaln( + x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca, + self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options, + ) + else: + x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) y = comfy.ldm.common_dit.rms_norm(x) y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp) @@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module): return x +def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype): + """Compute a single global prompt timestep for cross-attention ADaLN. + + Uses the max across tokens (matching JAX max_per_segment) and broadcasts + over text tokens. Returns None when *adaln_module* is None. + """ + if adaln_module is None: + return None + ts_input = ( + timestep_scaled.max(dim=1, keepdim=True).values.flatten() + if timestep_scaled.dim() > 1 + else timestep_scaled.flatten() + ) + prompt_ts, _ = adaln_module( + ts_input, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1]) + + +def apply_cross_attention_adaln( + x, context, attn, q_shift, q_scale, q_gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask=None, transformer_options={}, +): + """Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV). + + Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so + that both regular tensors and CompressedTimestep are supported. + """ + batch_size = x.shape[0] + shift_kv, scale_kv = ( + prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1) + ).unbind(dim=2) + attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift + encoder_hidden_states = context * (1 + scale_kv) + shift_kv + return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate + def get_fractional_positions(indices_grid, max_pos): n_pos_dims = indices_grid.shape[1] assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})' @@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC): vae_scale_factors: tuple = (8, 32, 32), use_middle_indices_grid=False, timestep_scale_multiplier = 1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, + caption_projection_first_linear=True, dtype=None, device=None, operations=None, @@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC): self.causal_temporal_positioning = causal_temporal_positioning self.operations = operations self.timestep_scale_multiplier = timestep_scale_multiplier + self.caption_proj_before_connector = caption_proj_before_connector + self.cross_attention_adaln = cross_attention_adaln + self.caption_projection_first_linear = caption_projection_first_linear # Common dimensions self.inner_dim = num_attention_heads * attention_head_dim @@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC): self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device ) + embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations ) - self.caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.cross_attention_adaln: + self.prompt_adaln_single = AdaLayerNormSingle( + self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + ) + else: + self.prompt_adaln_single = None + + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.caption_projection = lambda a: a + else: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) @abstractmethod def _init_model_components(self, device, dtype, **kwargs): @@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC): if grid_mask is not None: timestep = timestep[:, grid_mask] - timestep = timestep * self.timestep_scale_multiplier + timestep_scaled = timestep * self.timestep_scale_multiplier timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), + timestep_scaled.flatten(), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC): timestep = timestep.view(batch_size, -1, timestep.shape[-1]) embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) - return timestep, embedded_timestep + prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + + return timestep, embedded_timestep, prompt_timestep def _prepare_context(self, context, batch_size, x, attention_mask=None): """Prepare context for transformer blocks.""" - if self.caption_projection is not None: + if self.caption_proj_before_connector is False: context = self.caption_projection(context) - context = context.view(batch_size, -1, x.shape[-1]) + context = context.view(batch_size, -1, x.shape[-1]) return context, attention_mask def _precompute_freqs_cis( @@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC): merged_args.update(additional_args) # Prepare timestep and context - timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + merged_args["prompt_timestep"] = prompt_timestep context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) # Prepare attention mask and positional embeddings @@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel): causal_temporal_positioning=False, vae_scale_factors=(8, 32, 32), use_middle_indices_grid=False, - timestep_scale_multiplier = 1000.0, + timestep_scale_multiplier=1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel): def _init_model_components(self, device, dtype, **kwargs): """Initialize LTXV-specific components.""" - # No additional components needed for LTXV beyond base class pass def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel): self.num_attention_heads, self.attention_head_dim, context_dim=self.cross_attention_dim, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prompt_timestep = kwargs.get("prompt_timestep", None) for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask")) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep")) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel): pe=pe, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + prompt_timestep=prompt_timestep, ) return x diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index 55a074661..fa0a00748 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( CausalityAxis, CausalAudioAutoencoder, ) -from comfy.ldm.lightricks.vocoders.vocoder import Vocoder +from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE LATENT_DOWNSAMPLE_FACTOR = 4 @@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module): vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) - self.vocoder = Vocoder(config=component_config.vocoder) + if "bwe" in component_config.vocoder: + self.vocoder = VocoderWithBWE(config=component_config.vocoder) + else: + self.vocoder = Vocoder(config=component_config.vocoder) self.autoencoder.load_state_dict(vae_sd, strict=False) self.vocoder.load_state_dict(vocoder_sd, strict=False) diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py index f12b9bb53..b556b128f 100644 --- a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module): super().__init__() if config is None: - config = self._guess_config() + config = self.get_default_config() - # Extract encoder and decoder configs from the new format model_config = config.get("model", {}).get("params", {}) - variables_config = config.get("variables", {}) - self.sampling_rate = variables_config.get( - "sampling_rate", - model_config.get("sampling_rate", config.get("sampling_rate", 16000)), + self.sampling_rate = model_config.get( + "sampling_rate", config.get("sampling_rate", 16000) ) encoder_config = model_config.get("encoder", model_config.get("ddconfig", {})) decoder_config = model_config.get("decoder", encoder_config) # Load mel spectrogram parameters self.mel_bins = encoder_config.get("mel_bins", 64) - self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) - self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) + self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) + self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) # Store causality configuration at VAE level (not just in encoder internals) - causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value) + causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value) self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) self.is_causal = self.causality_axis == CausalityAxis.HEIGHT @@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module): self.per_channel_statistics = processor() - def _guess_config(self): - encoder_config = { - # Required parameters - based on ltx-video-av-1679000 model metadata - "ch": 128, - "out_ch": 8, - "ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8] - "num_res_blocks": 2, - "attn_resolutions": [], # Based on metadata: empty list, no attention - "dropout": 0.0, - "resamp_with_conv": True, - "in_channels": 2, # stereo - "resolution": 256, - "z_channels": 8, + def get_default_config(self): + ddconfig = { "double_z": True, - "attn_type": "vanilla", - "mid_block_add_attention": False, # Based on metadata: false + "mel_bins": 64, + "z_channels": 8, + "resolution": 256, + "downsample_time": False, + "in_channels": 2, + "out_ch": 2, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + "mid_block_add_attention": False, "norm_type": "pixel", - "causality_axis": "height", # Based on metadata - "mel_bins": 64, # Based on metadata: mel_bins = 64 - } - - decoder_config = { - # Inherits encoder config, can override specific params - **encoder_config, - "out_ch": 2, # Stereo audio output (2 channels) - "give_pre_end": False, - "tanh_out": False, + "causality_axis": "height", } config = { - "_class_name": "CausalAudioAutoencoder", - "sampling_rate": 16000, "model": { "params": { - "encoder": encoder_config, - "decoder": decoder_config, + "ddconfig": ddconfig, + "sampling_rate": 16000, } }, + "preprocessing": { + "stft": { + "filter_length": 1024, + "hop_length": 160, + }, + }, } return config diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index cbfdf412d..5b57dfc5e 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init +def in_meta_context(): + return torch.device("meta") == torch.empty(0).device + def mark_conv3d_ended(module): tid = threading.get_ident() for _, m in module.named_modules(): @@ -350,6 +353,10 @@ class Decoder(nn.Module): output_channel = output_channel * block_params.get("multiplier", 2) if block_name == "compress_all": output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_space": + output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_time": + output_channel = output_channel * block_params.get("multiplier", 1) self.conv_in = make_conv_nd( dims, @@ -395,17 +402,21 @@ class Decoder(nn.Module): spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(2, 1, 1), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(1, 2, 2), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all": @@ -455,6 +466,15 @@ class Decoder(nn.Module): output_channel * 2, 0, operations=ops, ) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) + else: + self.register_buffer( + "last_scale_shift_table", + torch.tensor( + [0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(2, output_channel), + persistent=False, + ) # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: @@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module): self.scale_shift_table = nn.Parameter( torch.randn(4, in_channels) / in_channels**0.5 ) + else: + self.register_buffer( + "scale_shift_table", + torch.tensor( + [0.0, 0.0, 0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(4, in_channels), + persistent=False, + ) self.temporal_cache_state={} @@ -1012,9 +1041,6 @@ class processor(nn.Module): super().__init__() self.register_buffer("std-of-means", torch.empty(128)) self.register_buffer("mean-of-means", torch.empty(128)) - self.register_buffer("mean-of-stds", torch.empty(128)) - self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128)) - self.register_buffer("channel", torch.empty(128)) def un_normalize(self, x): return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x) @@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module): super().__init__() if config is None: - config = self.guess_config(version) + config = self.get_default_config(version) + self.config = config self.timestep_conditioning = config.get("timestep_conditioning", False) + self.decode_noise_scale = config.get("decode_noise_scale", 0.025) + self.decode_timestep = config.get("decode_timestep", 0.05) double_z = config.get("double_z", True) latent_log_var = config.get( "latent_log_var", "per_channel" if double_z else "none" @@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module): latent_log_var=latent_log_var, norm_layer=config.get("norm_layer", "group_norm"), spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + base_channels=config.get("encoder_base_channels", 128), ) self.decoder = Decoder( @@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module): in_channels=config["latent_channels"], out_channels=config.get("out_channels", 3), blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), + base_channels=config.get("decoder_base_channels", 128), patch_size=config.get("patch_size", 1), norm_layer=config.get("norm_layer", "group_norm"), causal=config.get("causal_decoder", False), @@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module): self.per_channel_statistics = processor() - def guess_config(self, version): + def get_default_config(self, version): if version == 0: config = { "_class_name": "CausalVideoAutoencoder", @@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module): means, logvar = torch.chunk(self.encoder(x), 2, dim=1) return self.per_channel_statistics.normalize(means) - def decode(self, x, timestep=0.05, noise_scale=0.025): + def decode(self, x): if self.timestep_conditioning: #TODO: seed - x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x - return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep) - + x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x + return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index b1f15f2c5..6c4028aa8 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -3,6 +3,7 @@ import torch.nn.functional as F import torch.nn as nn import comfy.ops import numpy as np +import math ops = comfy.ops.disable_weight_init @@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) +# --------------------------------------------------------------------------- +# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2 +# Adopted from https://github.com/NVIDIA/BigVGAN +# --------------------------------------------------------------------------- + + +def _sinc(x: torch.Tensor): + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time) + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride=1, + padding=True, + padding_mode="replicate", + kernel_size=12, + ): + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + def forward(self, x): + _, C, _ = x.shape + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"): + super().__init__() + self.ratio = ratio + self.stride = ratio + + if window_type == "hann": + # Hann-windowed sinc filter — identical to torchaudio.functional.resample + # with its default parameters (rolloff=0.99, lowpass_filter_width=6). + # Uses replicate boundary padding, matching the reference resampler exactly. + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + t = (torch.arange(self.kernel_size) / ratio - width) * rolloff + t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2 + filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser-windowed sinc filter (BigVGAN default). + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + + self.register_buffer("filter", filter, persistent=persistent) + + def forward(self, x): + _, C, _ = x.shape + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + return self.lowpass(x) + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio=2, + down_ratio=2, + up_kernel_size=12, + down_kernel_size=12, + ): + super().__init__() + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +# --------------------------------------------------------------------------- +# BigVGAN v2 activations (Snake / SnakeBeta) +# --------------------------------------------------------------------------- + + +class Snake(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = self.alpha.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + a = torch.exp(a) + return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2) + + +class SnakeBeta(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.beta = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.beta.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = self.alpha.unsqueeze(0).unsqueeze(-1) + b = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + a = torch.exp(a) + b = torch.exp(b) + return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2) + + +# --------------------------------------------------------------------------- +# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity) +# --------------------------------------------------------------------------- + + +class AMPBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"): + super().__init__() + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + self.acts1 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs1))] + ) + self.acts2 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs2))] + ) + + def forward(self, x): + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = x + xt + return x + + +# --------------------------------------------------------------------------- +# HiFi-GAN residual blocks +# --------------------------------------------------------------------------- + + class ResBlock1(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlock1, self).__init__() @@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module): """ Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. + Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1"). """ def __init__(self, config=None): @@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module): config = self.get_default_config() resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) - upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2]) - upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]) + upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4]) resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) upsample_initial_channel = config.get("upsample_initial_channel", 1024) stereo = config.get("stereo", True) - resblock = config.get("resblock", "1") + activation = config.get("activation", "snake") + use_bias_at_final = config.get("use_bias_at_final", True) + + # "output_sample_rate" is not present in recent checkpoint configs. + # When absent (None), AudioVAE.output_sample_rate computes it as: + # sample_rate * vocoder.upsample_factor / mel_hop_length + # where upsample_factor = product of all upsample stride lengths, + # and mel_hop_length is loaded from the autoencoder config at + # preprocessing.stft.hop_length (see CausalAudioAutoencoder). self.output_sample_rate = config.get("output_sample_rate") + self.resblock = config.get("resblock", "1") + self.use_tanh_at_final = config.get("use_tanh_at_final", True) + self.apply_final_activation = config.get("apply_final_activation", True) self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) - resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + if self.resblock == "1": + resblock_cls = ResBlock1 + elif self.resblock == "2": + resblock_cls = ResBlock2 + elif self.resblock == "AMP1": + resblock_cls = AMPBlock1 + else: + raise ValueError(f"Unknown resblock type: {self.resblock}") self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): @@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module): self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock_class(ch, k, d)) + for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes): + if self.resblock == "AMP1": + self.resblocks.append(resblock_cls(ch, k, d, activation=activation)) + else: + self.resblocks.append(resblock_cls(ch, k, d)) out_channels = 2 if stereo else 1 - self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3) + if self.resblock == "AMP1": + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.act_post = Activation1d(act_cls(ch)) + else: + self.act_post = nn.LeakyReLU() + + self.conv_post = ops.Conv1d( + ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final + ) self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + def get_default_config(self): """Generate default configuration for the vocoder.""" config = { "resblock_kernel_sizes": [3, 7, 11], - "upsample_rates": [6, 5, 2, 2, 2], - "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_initial_channel": 1024, "stereo": True, "resblock": "1", + "activation": "snake", + "use_bias_at_final": True, + "use_tanh_at_final": True, } return config @@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module): assert x.shape[1] == 2, "Input must have 2 channels for stereo" x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) x = self.conv_pre(x) + for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) + if self.resblock != "AMP1": + x = F.leaky_relu(x, LRELU_SLOPE) x = self.ups[i](x) xs = None for j in range(self.num_kernels): @@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module): else: xs += self.resblocks[i * self.num_kernels + j](x) x = xs / self.num_kernels - x = F.leaky_relu(x) + + x = self.act_post(x) x = self.conv_post(x) - x = torch.tanh(x) + + if self.apply_final_activation: + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, -1, 1) return x + + +class _STFTFn(nn.Module): + """Implements STFT as a convolution with precomputed DFT × Hann-window bases. + + The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal + Hann window are stored as buffers and loaded from the checkpoint. Using the exact + bfloat16 bases from training ensures the mel values fed to the BWE generator are + bit-identical to what it was trained on. + """ + + def __init__(self, filter_length: int, hop_length: int, win_length: int): + super().__init__() + self.hop_length = hop_length + self.win_length = win_length + n_freqs = filter_length // 2 + 1 + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + + def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Compute magnitude and phase spectrogram from a batch of waveforms. + + Applies causal (left-only) padding of win_length - hop_length samples so that + each output frame depends only on past and present input — no lookahead. + The STFT is computed by convolving the padded signal with forward_basis. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + Computed in float32 for numerical stability, then cast back to + the input dtype. + """ + if y.dim() == 2: + y = y.unsqueeze(1) # (B, 1, T) + left_pad = max(0, self.win_length - self.hop_length) # causal: left-only + y = F.pad(y, (left_pad, 0)) + spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real ** 2 + imag ** 2) + phase = torch.atan2(imag.float(), real.float()).to(real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint. + + Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input + waveform and projecting the linear magnitude spectrum onto the mel filterbank. + + The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint + (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis). + """ + + def __init__( + self, + filter_length: int, + hop_length: int, + win_length: int, + n_mel_channels: int, + sampling_rate: int, + mel_fmin: float, + mel_fmax: float, + ): + super().__init__() + self.stft_fn = _STFTFn(filter_length, hop_length, win_length) + + n_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs)) + + def mel_spectrogram( + self, y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute log-mel spectrogram and auxiliary spectral quantities. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames). + Computed as log(clamp(mel_basis @ magnitude, min=1e-5)). + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames). + """ + magnitude, phase = self.stft_fn(y) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class VocoderWithBWE(torch.nn.Module): + """Vocoder with bandwidth extension (BWE) for higher sample rate output. + + Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples + to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform. + """ + + def __init__(self, config): + super().__init__() + vocoder_config = config["vocoder"] + bwe_config = config["bwe"] + + self.vocoder = Vocoder(config=vocoder_config) + self.bwe_generator = Vocoder( + config={**bwe_config, "apply_final_activation": False} + ) + + self.input_sample_rate = bwe_config["input_sampling_rate"] + self.output_sample_rate = bwe_config["output_sampling_rate"] + self.hop_length = bwe_config["hop_length"] + + self.mel_stft = MelSTFT( + filter_length=bwe_config["n_fft"], + hop_length=bwe_config["hop_length"], + win_length=bwe_config["n_fft"], + n_mel_channels=bwe_config["num_mels"], + sampling_rate=bwe_config["input_sampling_rate"], + mel_fmin=0.0, + mel_fmax=bwe_config["input_sampling_rate"] / 2.0, + ) + self.resampler = UpSample1d( + ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"], + persistent=False, + window_type="hann", + ) + + def _compute_mel(self, audio): + """Compute log-mel spectrogram from waveform using causal STFT bases.""" + B, C, T = audio.shape + flat = audio.reshape(B * C, -1) # (B*C, T) + mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames) + return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames) + + def forward(self, mel_spec): + x = self.vocoder(mel_spec) + _, _, T_low = x.shape + T_out = T_low * self.output_sample_rate // self.input_sample_rate + + remainder = T_low % self.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + mel = self._compute_mel(x) + residual = self.bwe_generator(mel) + skip = self.resampler(x) + assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" + + return torch.clamp(residual + skip, -1, 1)[..., :T_out] diff --git a/comfy/model_base.py b/comfy/model_base.py index 1e01e9edc..d9d5a9293 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1021,7 +1021,7 @@ class LTXAV(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: if hasattr(self.diffusion_model, "preprocess_text_embeds"): - cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference())) + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False)) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) diff --git a/comfy/sd.py b/comfy/sd.py index 8bcd09582..888ef1e77 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1467,7 +1467,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage elif clip_type == CLIPType.LTXV: - clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data)) + clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.NEWBIE: diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e86ea9f4e..5e1273c6e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel): comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is +class DualLinearProjection(torch.nn.Module): + def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None): + super().__init__() + self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device) + self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device) + + def forward(self, x): + source_dim = x.shape[-1] + x = x.movedim(1, -1) + x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2) + + video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim)) + audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim)) + return torch.cat((video, audio), dim=-1) + class LTXAVTEModel(torch.nn.Module): - def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}): super().__init__() self.dtypes = set() self.dtypes.add(dtype) self.compat_mode = False + self.text_projection_type = text_projection_type self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None) self.dtypes.add(dtype_llama) operations = self.gemma3_12b.operations # TODO - self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + + if self.text_projection_type == "single_linear": + self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + elif self.text_projection_type == "dual_linear": + self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations) + def enable_compat_mode(self): # TODO: remove from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module): out_device = out.device if comfy.model_management.should_use_bf16(self.execution_device): out = out.to(device=self.execution_device, dtype=torch.bfloat16) - out = out.movedim(1, -1).to(self.execution_device) - out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) - out = out.reshape((out.shape[0], out.shape[1], -1)) - out = self.text_embedding_projection(out) - out = out.float() - if self.compat_mode: - out_vid = self.video_embeddings_connector(out)[0] - out_audio = self.audio_embeddings_connector(out)[0] - out = torch.concat((out_vid, out_audio), dim=-1) + if self.text_projection_type == "single_linear": + out = out.movedim(1, -1).to(self.execution_device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) - return out.to(out_device), pooled + if self.compat_mode: + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + extra = {} + else: + extra = {"unprocessed_ltxav_embeds": True} + elif self.text_projection_type == "dual_linear": + out = self.text_embedding_projection(out) + extra = {"unprocessed_ltxav_embeds": True} + + return out.to(device=out_device, dtype=torch.float), pooled, extra def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed) @@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module): if "model.layers.47.self_attn.q_norm.weight" in sd: return self.gemma3_12b.load_sd(sd) else: - sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True) + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True) if len(sdo) == 0: sdo = sd @@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module): num_tokens = max(num_tokens, 642) return num_tokens * constant * 1024 * 1024 -def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): +def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"): class LTXAVTEModel_(LTXAVTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: @@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama - super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options) return LTXAVTEModel_ + +def sd_detect(state_dict_list, prefix=""): + for sd in state_dict_list: + if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd: + return {"text_projection_type": "dual_linear"} + if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd: + return {"text_projection_type": "single_linear"} + return {} + + def gemma3_te(dtype_llama=None, llama_quantization_metadata=None): class Gemma3_12BModel_(Gemma3_12BModel): def __init__(self, device="cpu", dtype=None, model_options={}): From f2ee7f2d367f98bb8a33bcb4a224bda441eb8a07 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 22:21:55 -0800 Subject: [PATCH 30/75] Fix cublas ops on dynamic vram. (#12776) --- comfy/ops.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 8275dd0a5..3e19cd1b6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -660,23 +660,29 @@ class fp8_ops(manual_cast): CUBLAS_IS_AVAILABLE = False try: - from cublas_ops import CublasLinear + from cublas_ops import CublasLinear, cublas_half_matmul CUBLAS_IS_AVAILABLE = True except ImportError: pass if CUBLAS_IS_AVAILABLE: - class cublas_ops(disable_weight_init): - class Linear(CublasLinear, disable_weight_init.Linear): + class cublas_ops(manual_cast): + class Linear(CublasLinear, manual_cast.Linear): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): - return super().forward(input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): - return super().forward(*args, **kwargs) - + run_every_op() + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) # ============================================================================== # Mixed Precision Operations From c5fe8ace68c432a262a5093bdd84b3ed70b9d283 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 5 Mar 2026 15:37:35 +0800 Subject: [PATCH 31/75] chore: update workflow templates to v0.9.6 (#12778) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index dae46d873..5f99407b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.5 +comfyui-workflow-templates==0.9.6 comfyui-embedded-docs==0.4.3 torch torchsde From 4941671b5a5c65fea48be922caa76b7f6a0a4595 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 23:39:51 -0800 Subject: [PATCH 32/75] Fix cuda getting initialized in cpu mode. (#12779) --- comfy/model_management.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 809600815..ee28ea107 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1666,12 +1666,16 @@ def lora_compute_dtype(device): return dtype def synchronize(): + if cpu_mode(): + return if is_intel_xpu(): torch.xpu.synchronize() elif torch.cuda.is_available(): torch.cuda.synchronize() def soft_empty_cache(force=False): + if cpu_mode(): + return global cpu_state if cpu_state == CPUState.MPS: torch.mps.empty_cache() From c8428541a6b6e4b1e0fbd685e9c846efcb60179e Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 5 Mar 2026 16:58:25 +0800 Subject: [PATCH 33/75] chore: update workflow templates to v0.9.7 (#12780) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5f99407b7..866818e08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.6 +comfyui-workflow-templates==0.9.7 comfyui-embedded-docs==0.4.3 torch torchsde From e04d0dbeb8266aa9262b5a4c3934ba4e4a371e37 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 04:06:29 -0500 Subject: [PATCH 34/75] ComfyUI v0.16.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 6a35c6de3..0aea18d3a 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.15.1" +__version__ = "0.16.0" diff --git a/pyproject.toml b/pyproject.toml index 1b2318273..f2133d99c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.15.1" +version = "0.16.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From bd21363563ce8e312c9271a0c64a0145335df8a9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:29:39 +0200 Subject: [PATCH 35/75] feat(api-nodes-xAI): updated models, pricing, added features (#12756) --- comfy_api_nodes/apis/grok.py | 14 ++++-- comfy_api_nodes/nodes_grok.py | 92 +++++++++++++++++++++++++++++------ 2 files changed, 87 insertions(+), 19 deletions(-) diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py index 8e3c79ab9..c56c8aecc 100644 --- a/comfy_api_nodes/apis/grok.py +++ b/comfy_api_nodes/apis/grok.py @@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel): aspect_ratio: str = Field(...) n: int = Field(...) seed: int = Field(...) - response_for: str = Field("url") + response_format: str = Field("url") + resolution: str = Field(...) class InputUrlObject(BaseModel): @@ -16,12 +17,13 @@ class InputUrlObject(BaseModel): class ImageEditRequest(BaseModel): model: str = Field(...) - image: InputUrlObject = Field(...) + images: list[InputUrlObject] = Field(...) prompt: str = Field(...) resolution: str = Field(...) n: int = Field(...) seed: int = Field(...) - response_for: str = Field("url") + response_format: str = Field("url") + aspect_ratio: str | None = Field(...) class VideoGenerationRequest(BaseModel): @@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel): revised_prompt: str | None = Field(None) +class UsageObject(BaseModel): + cost_in_usd_ticks: int | None = Field(None) + + class ImageGenerationResponse(BaseModel): data: list[ImageResponseObject] = Field(...) + usage: UsageObject | None = Field(None) class VideoGenerationResponse(BaseModel): @@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel): status: str | None = Field(None) video: VideoResponseObject | None = Field(None) model: str | None = Field(None) + usage: UsageObject | None = Field(None) diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index da15e97ea..0716d6239 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -27,6 +27,12 @@ from comfy_api_nodes.util import ( ) +def _extract_grok_price(response) -> float | None: + if response.usage and response.usage.cost_in_usd_ticks is not None: + return response.usage.cost_in_usd_ticks / 10_000_000_000 + return None + + class GrokImageNode(IO.ComfyNode): @classmethod @@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode): category="api node/image/Grok", description="Generate images using Grok based on a text prompt", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-image-beta"]), + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), IO.String.Input( "prompt", multiline=True, @@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), + IO.Combo.Input("resolution", options=["1K", "2K"], optional=True), ], outputs=[ IO.Image.Output(), @@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), - expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""", + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; + {"type":"usd","usd": $rate * widgets.number_of_images} + ) + """, ), ) @@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode): aspect_ratio: str, number_of_images: int, seed: int, + resolution: str = "1K", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) response = await sync_op( @@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode): aspect_ratio=aspect_ratio, n=number_of_images, seed=seed, + resolution=resolution.lower(), ), response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) @@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode): category="api node/image/Grok", description="Modify an existing image based on a text prompt", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-image-beta"]), - IO.Image.Input("image"), + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), + IO.Image.Input("image", display_name="images"), IO.String.Input( "prompt", multiline=True, tooltip="The text prompt used to generate the image", ), - IO.Combo.Input("resolution", options=["1K"]), + IO.Combo.Input("resolution", options=["1K", "2K"]), IO.Int.Input( "number_of_images", default=1, @@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "9:16", + "16:9", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", + ], + optional=True, + tooltip="Only allowed when multiple images are connected to the image input.", + ), ], outputs=[ IO.Image.Output(), @@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), - expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""", + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; + {"type":"usd","usd": 0.002 + $rate * widgets.number_of_images} + ) + """, ), ) @@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode): resolution: str, number_of_images: int, seed: int, + aspect_ratio: str = "auto", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) - if get_number_of_images(image) != 1: - raise ValueError("Only one input image is supported.") + if model == "grok-imagine-image-pro": + if get_number_of_images(image) > 1: + raise ValueError("The pro model supports only 1 input image.") + elif get_number_of_images(image) > 3: + raise ValueError("A maximum of 3 input images is supported.") + if aspect_ratio != "auto" and get_number_of_images(image) == 1: + raise ValueError( + "Custom aspect ratio is only allowed when multiple images are connected to the image input." + ) response = await sync_op( cls, ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), data=ImageEditRequest( model=model, - image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"), + images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image], prompt=prompt, resolution=resolution.lower(), n=number_of_images, seed=seed, + aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, ), response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) @@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode): category="api node/video/Grok", description="Generate video from a prompt or an image", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), IO.String.Input( "prompt", multiline=True, @@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]), + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]), expr=""" ( - $base := 0.181 * widgets.duration; + $rate := widgets.resolution = "720p" ? 0.07 : 0.05; + $base := $rate * widgets.duration; {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} ) """, @@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url)) @@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode): category="api node/video/Grok", description="Edit an existing video based on a text prompt.", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), IO.String.Input( "prompt", multiline=True, @@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""", + expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""", ), ) @@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url)) From 9cdfd7403bc46f75d12be16ba6041b8bcdd3f7fd Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 5 Mar 2026 17:12:38 +0200 Subject: [PATCH 36/75] feat(api-nodes): enable Kling 3.0 Motion Control (#12785) --- comfy_api_nodes/apis/kling.py | 1 + comfy_api_nodes/nodes_kling.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/comfy_api_nodes/apis/kling.py b/comfy_api_nodes/apis/kling.py index a5bd5f1d3..fe0f97cb3 100644 --- a/comfy_api_nodes/apis/kling.py +++ b/comfy_api_nodes/apis/kling.py @@ -148,3 +148,4 @@ class MotionControlRequest(BaseModel): keep_original_sound: str = Field(...) character_orientation: str = Field(...) mode: str = Field(..., description="'pro' or 'std'") + model_name: str = Field(...) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 74fa078ff..8963c335d 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -2747,6 +2747,7 @@ class MotionControl(IO.ComfyNode): "but the character orientation matches the reference image (camera/other details via prompt).", ), IO.Combo.Input("mode", options=["pro", "std"]), + IO.Combo.Input("model", options=["kling-v3", "kling-v2-6"], optional=True), ], outputs=[ IO.Video.Output(), @@ -2777,6 +2778,7 @@ class MotionControl(IO.ComfyNode): keep_original_sound: bool, character_orientation: str, mode: str, + model: str = "kling-v2-6", ) -> IO.NodeOutput: validate_string(prompt, max_length=2500) validate_image_dimensions(reference_image, min_width=340, min_height=340) @@ -2797,6 +2799,7 @@ class MotionControl(IO.ComfyNode): keep_original_sound="yes" if keep_original_sound else "no", character_orientation=character_orientation, mode=mode, + model_name=model, ), ) if response.code: From da29b797ce00b491c269e864cc3b8fceb279e530 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 5 Mar 2026 23:23:23 +0800 Subject: [PATCH 37/75] Update workflow templates to v0.9.8 (#12788) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 866818e08..3fd44e0cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.7 +comfyui-workflow-templates==0.9.8 comfyui-embedded-docs==0.4.3 torch torchsde From 6ef82a89b83a49247081dc57b154172573c9e313 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 10:38:33 -0500 Subject: [PATCH 38/75] ComfyUI v0.16.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 0aea18d3a..e58e0fb63 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.0" +__version__ = "0.16.1" diff --git a/pyproject.toml b/pyproject.toml index f2133d99c..199a90364 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.0" +version = "0.16.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 6481569ad4c3606bc50e9de39ce810651690ae79 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 5 Mar 2026 09:04:24 -0800 Subject: [PATCH 39/75] comfy-aimdo 0.2.7 (#12791) Comfy-aimdo 0.2.7 fixes a crash when a spurious cudaAsyncFree comes in and would cause an infinite stack overflow (via detours hooks). A lock is also introduced on the link list holding the free sections to avoid any possibility of threaded miscellaneous cuda allocations being the root cause. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3fd44e0cf..f7098b730 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.6 +comfy-aimdo>=0.2.7 requests #non essential dependencies: From 42e0e023eee6a19c1adb7bd3dc11c81ff6dcc9c8 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 5 Mar 2026 10:22:17 -0800 Subject: [PATCH 40/75] ops: Handle CPU weight in VBAR caster (#12792) This shouldn't happen but custom nodes gets there. Handle it as best we can. --- comfy/ops.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index 3e19cd1b6..06aa41d4f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -80,6 +80,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): + + #vbar doesn't support CPU weights, but some custom nodes have weird paths + #that might switch the layer to the CPU and expect it to work. We have to take + #a clone conservatively as we are mmapped and some SFT files are packed misaligned + #If you are a custom node author reading this, please move your layer to the GPU + #or declare your ModelPatcher as CPU in the first place. + if device is not None and device.type == "cpu": + weight = s.weight.to(dtype=dtype, copy=True) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + bias = None + if s.bias is not None: + bias = s.bias.to(dtype=bias_dtype, copy=True) + return weight, bias, (None, None, None) + offload_stream = None xfer_dest = None From 5073da57ad20a2abb921f79458e49a7f7d608740 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 6 Mar 2026 02:22:38 +0800 Subject: [PATCH 41/75] chore: update workflow templates to v0.9.10 (#12793) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f7098b730..9a674fac5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.8 +comfyui-workflow-templates==0.9.10 comfyui-embedded-docs==0.4.3 torch torchsde From 1c3b651c0a1539a374e3d29a3ce695b5844ac5fc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Mar 2026 10:35:56 -0800 Subject: [PATCH 42/75] Refactor. (#12794) --- comfy/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 06aa41d4f..87b36b5c5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -86,7 +86,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu #a clone conservatively as we are mmapped and some SFT files are packed misaligned #If you are a custom node author reading this, please move your layer to the GPU #or declare your ModelPatcher as CPU in the first place. - if device is not None and device.type == "cpu": + if comfy.model_management.is_device_cpu(device): weight = s.weight.to(dtype=dtype, copy=True) if isinstance(weight, QuantizedTensor): weight = weight.dequantize() From 50549aa252903b936b2ed00b5de418c8b47f0841 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 13:41:06 -0500 Subject: [PATCH 43/75] ComfyUI v0.16.2 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index e58e0fb63..bc49f2218 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.1" +__version__ = "0.16.2" diff --git a/pyproject.toml b/pyproject.toml index 199a90364..73bfd1007 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.1" +version = "0.16.2" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 8befce5c7b84ff3451a6bd3bcbae1355ad322855 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 5 Mar 2026 22:37:25 +0200 Subject: [PATCH 44/75] Add manual cast to LTX2 vocoder conv_transpose1d (#12795) * Add manual cast to LTX2 vocoder * Update vocoder.py --- comfy/ldm/lightricks/vocoders/vocoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index 6c4028aa8..a0e03cada 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F import torch.nn as nn import comfy.ops +import comfy.model_management import numpy as np import math @@ -125,7 +126,7 @@ class UpSample1d(nn.Module): _, C, _ = x.shape x = F.pad(x, (self.pad, self.pad), mode="replicate") x = self.ratio * F.conv_transpose1d( - x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C ) x = x[..., self.pad_left : -self.pad_right] return x From 17b43c2b87eba43f0f071471b855e0ed659a2627 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Mar 2026 13:31:28 -0800 Subject: [PATCH 45/75] LTX audio vae novram fixes. (#12796) --- comfy/ldm/lightricks/vocoders/vocoder.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index a0e03cada..2481d8bdd 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -82,7 +82,7 @@ class LowPassFilter1d(nn.Module): _, C, _ = x.shape if self.padding: x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) - return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C) class UpSample1d(nn.Module): @@ -191,7 +191,7 @@ class Snake(nn.Module): self.eps = 1e-9 def forward(self, x): - a = self.alpha.unsqueeze(0).unsqueeze(-1) + a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) if self.alpha_logscale: a = torch.exp(a) return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2) @@ -218,8 +218,8 @@ class SnakeBeta(nn.Module): self.eps = 1e-9 def forward(self, x): - a = self.alpha.unsqueeze(0).unsqueeze(-1) - b = self.beta.unsqueeze(0).unsqueeze(-1) + a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) + b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) if self.alpha_logscale: a = torch.exp(a) b = torch.exp(b) @@ -597,7 +597,7 @@ class _STFTFn(nn.Module): y = y.unsqueeze(1) # (B, 1, T) left_pad = max(0, self.win_length - self.hop_length) # causal: left-only y = F.pad(y, (left_pad, 0)) - spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) + spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0) n_freqs = spec.shape[1] // 2 real, imag = spec[:, :n_freqs], spec[:, n_freqs:] magnitude = torch.sqrt(real ** 2 + imag ** 2) @@ -648,7 +648,7 @@ class MelSTFT(nn.Module): """ magnitude, phase = self.stft_fn(y) energy = torch.norm(magnitude, dim=1) - mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude) log_mel = torch.log(torch.clamp(mel, min=1e-5)) return log_mel, magnitude, phase, energy From 58017e8726bdddae89704b1e0123bedc29994424 Mon Sep 17 00:00:00 2001 From: Tavi Halperin Date: Thu, 5 Mar 2026 23:51:20 +0200 Subject: [PATCH 46/75] feat: add causal_fix parameter to add_keyframe_index and append_keyframe (#12797) Allows explicit control over the causal_fix flag passed to latent_to_pixel_coords. Defaults to frame_idx == 0 when not specified, fixing the previous heuristic. --- comfy_extras/nodes_lt.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 32fe921ff..c05571143 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -253,10 +253,12 @@ class LTXVAddGuide(io.ComfyNode): return frame_idx, latent_idx @classmethod - def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1): + def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1, causal_fix=None): keyframe_idxs, _ = get_keyframe_idxs(cond) _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent) - pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 + if causal_fix is None: + causal_fix = frame_idx == 0 or guiding_latent.shape[2] == 1 + pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=causal_fix) pixel_coords[:, 0] += frame_idx # The following adjusts keyframe end positions for small grid IC-LoRA. @@ -278,12 +280,12 @@ class LTXVAddGuide(io.ComfyNode): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1): + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1, causal_fix=None): if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: raise ValueError("Adding guide to a combined AV latent is not supported.") - positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) - negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) + positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix) + negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix) if guide_mask is not None: target_h = max(noise_mask.shape[3], guide_mask.shape[3]) From 1c218282369a6cc80651d878fc51fa33d7bf34e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 17:25:49 -0500 Subject: [PATCH 47/75] ComfyUI v0.16.3 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index bc49f2218..5da21150b 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.2" +__version__ = "0.16.3" diff --git a/pyproject.toml b/pyproject.toml index 73bfd1007..6a83c5c63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.2" +version = "0.16.3" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From e544c65db91df5a070be69a0a9b922201fe79335 Mon Sep 17 00:00:00 2001 From: Dante Date: Fri, 6 Mar 2026 11:51:28 +0900 Subject: [PATCH 48/75] feat: add Math Expression node with simpleeval evaluation (#12687) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add EagerEval dataclass for frontend-side node evaluation Add EagerEval to the V3 API schema, enabling nodes to declare frontend-evaluated JSONata expressions. The frontend uses this to display computation results as badges without a backend round-trip. Co-Authored-By: Claude Opus 4.6 * feat: add Math Expression node with JSONata evaluation Add ComfyMathExpression node that evaluates JSONata expressions against dynamically-grown numeric inputs using Autogrow + MatchType. Sends input context via ui output so the frontend can re-evaluate when the expression changes without a backend round-trip. Co-Authored-By: Claude Opus 4.6 * feat: register nodes_math.py in extras_files loader list Co-Authored-By: Claude Opus 4.6 * fix: address CodeRabbit review feedback - Harden EagerEval.validate with type checks and strip() for empty strings - Add _positional_alias for spreadsheet-style names beyond z (aa, ab...) - Validate JSONata result is numeric before returning - Add jsonata to requirements.txt Co-Authored-By: Claude Opus 4.6 * refactor: remove EagerEval, scope PR to math node only Remove EagerEval dataclass from _io.py and eager_eval usage from nodes_math.py. Eager execution will be designed as a general-purpose system in a separate effort. Co-Authored-By: Claude Opus 4.6 * fix: use TemplateNames, cap inputs at 26, improve error message Address Kosinkadink review feedback: - Switch from Autogrow.TemplatePrefix to Autogrow.TemplateNames so input slots are named a-z, matching expression variables directly - Cap max inputs at 26 (a-z) instead of 100 - Simplify execute() by removing dual-mapping hack - Include expression and result value in error message Co-Authored-By: Claude Opus 4.6 * test: add unit tests for Math Expression node Add tests for _positional_alias (a-z mapping) and execute() covering arithmetic operations, float inputs, $sum(values), and error cases. Co-Authored-By: Claude Opus 4.6 * refactor: replace jsonata with simpleeval for math evaluation jsonata PyPI package has critical issues: no Python 3.12/3.13 wheels, no ARM/Apple Silicon wheels, abandoned (last commit 2023), C extension. Replace with simpleeval (pure Python, 3.4M downloads/month, MIT, AST-based security). Add math module functions (sqrt, ceil, floor, log, sin, cos, tan) and variadic sum() supporting both sum(values) and sum(a, b, c). Pin version to >=1.0,<2.0. Co-Authored-By: Claude Opus 4.6 * test: update tests for simpleeval migration Update JSONata syntax to Python syntax ($sum -> sum, $string -> str), add tests for math functions (sqrt, ceil, floor, sin, log10) and variadic sum(a, b, c). Co-Authored-By: Claude Opus 4.6 * refactor: replace MatchType with MultiType inputs and dual FLOAT/INT outputs Allow mixing INT and FLOAT connections on the same node by switching from MatchType (which forces all inputs to the same type) to MultiType. Output both FLOAT and INT so users can pick the type they need. Co-Authored-By: Claude Opus 4.6 * test: update tests for mixed INT/FLOAT inputs and dual outputs Add assertions for both FLOAT (result[0]) and INT (result[1]) outputs. Add test_mixed_int_float_inputs and test_mixed_resolution_scale to verify the primary use case of multiplying resolutions by a float factor. Co-Authored-By: Claude Opus 4.6 * feat: make expression input multiline and validate empty expression - Add multiline=True to expression input for better UX with longer expressions - Add empty expression validation with clear "Expression cannot be empty." message Co-Authored-By: Claude Opus 4.6 * test: add tests for empty expression validation Co-Authored-By: Claude Opus 4.6 * fix: address review feedback — safe pow, isfinite guard, test coverage - Wrap pow() with _safe_pow to prevent DoS via huge exponents (pow() bypasses simpleeval's safe_power guard on **) - Add math.isfinite() check to catch inf/nan before int() conversion - Add int/float converters to MATH_FUNCTIONS for explicit casting - Add "calculator" search alias - Replace _positional_alias helper with string.ascii_lowercase - Narrow test assertions and add error path + function coverage tests Co-Authored-By: Claude Opus 4.6 * Update requirements.txt --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: Jedrzej Kosinski Co-authored-by: Christian Byrne --- comfy_extras/nodes_math.py | 119 +++++++++++ nodes.py | 1 + requirements.txt | 1 + .../comfy_extras_test/nodes_math_test.py | 197 ++++++++++++++++++ 4 files changed, 318 insertions(+) create mode 100644 comfy_extras/nodes_math.py create mode 100644 tests-unit/comfy_extras_test/nodes_math_test.py diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py new file mode 100644 index 000000000..6417bacf1 --- /dev/null +++ b/comfy_extras/nodes_math.py @@ -0,0 +1,119 @@ +"""Math expression node using simpleeval for safe evaluation. + +Provides a ComfyMathExpression node that evaluates math expressions +against dynamically-grown numeric inputs. +""" + +from __future__ import annotations + +import math +import string + +from simpleeval import simple_eval +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +MAX_EXPONENT = 4000 + + +def _variadic_sum(*args): + """Support both sum(values) and sum(a, b, c).""" + if len(args) == 1 and hasattr(args[0], "__iter__"): + return sum(args[0]) + return sum(args) + + +def _safe_pow(base, exp): + """Wrap pow() with an exponent cap to prevent DoS via huge exponents. + + The ** operator is already guarded by simpleeval's safe_power, but + pow() as a callable bypasses that guard. + """ + if abs(exp) > MAX_EXPONENT: + raise ValueError(f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})") + return pow(base, exp) + + +MATH_FUNCTIONS = { + "sum": _variadic_sum, + "min": min, + "max": max, + "abs": abs, + "round": round, + "pow": _safe_pow, + "sqrt": math.sqrt, + "ceil": math.ceil, + "floor": math.floor, + "log": math.log, + "log2": math.log2, + "log10": math.log10, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "int": int, + "float": float, +} + + +class MathExpressionNode(io.ComfyNode): + """Evaluates a math expression against dynamically-grown inputs.""" + + @classmethod + def define_schema(cls) -> io.Schema: + autogrow = io.Autogrow.TemplateNames( + input=io.MultiType.Input("value", [io.Float, io.Int]), + names=list(string.ascii_lowercase), + min=1, + ) + return io.Schema( + node_id="ComfyMathExpression", + display_name="Math Expression", + category="math", + search_aliases=[ + "expression", "formula", "calculate", "calculator", + "eval", "math", + ], + inputs=[ + io.String.Input("expression", default="a + b", multiline=True), + io.Autogrow.Input("values", template=autogrow), + ], + outputs=[ + io.Float.Output(display_name="FLOAT"), + io.Int.Output(display_name="INT"), + ], + ) + + @classmethod + def execute( + cls, expression: str, values: io.Autogrow.Type + ) -> io.NodeOutput: + if not expression.strip(): + raise ValueError("Expression cannot be empty.") + + context: dict = dict(values) + context["values"] = list(values.values()) + + result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS) + # bool check must come first because bool is a subclass of int in Python + if isinstance(result, bool) or not isinstance(result, (int, float)): + raise ValueError( + f"Math Expression '{expression}' must evaluate to a numeric result, " + f"got {type(result).__name__}: {result!r}" + ) + if not math.isfinite(result): + raise ValueError( + f"Math Expression '{expression}' produced a non-finite result: {result}" + ) + return io.NodeOutput(float(result), int(result)) + + +class MathExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [MathExpressionNode] + + +async def comfy_entrypoint() -> MathExtension: + return MathExtension() diff --git a/nodes.py b/nodes.py index 5be9b16f9..0ef23b640 100644 --- a/nodes.py +++ b/nodes.py @@ -2449,6 +2449,7 @@ async def init_builtin_extra_nodes(): "nodes_replacements.py", "nodes_nag.py", "nodes_sdpose.py", + "nodes_math.py", ] import_failed = [] diff --git a/requirements.txt b/requirements.txt index 9a674fac5..7bf12247c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,6 +24,7 @@ av>=14.2.0 comfy-kitchen>=0.2.7 comfy-aimdo>=0.2.7 requests +simpleeval>=1.0 #non essential dependencies: kornia>=0.7.1 diff --git a/tests-unit/comfy_extras_test/nodes_math_test.py b/tests-unit/comfy_extras_test/nodes_math_test.py new file mode 100644 index 000000000..fa4cdcac3 --- /dev/null +++ b/tests-unit/comfy_extras_test/nodes_math_test.py @@ -0,0 +1,197 @@ +import math + +import pytest +from collections import OrderedDict +from unittest.mock import patch, MagicMock + +mock_nodes = MagicMock() +mock_nodes.MAX_RESOLUTION = 16384 +mock_server = MagicMock() + +with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}): + from comfy_extras.nodes_math import MathExpressionNode + + +class TestMathExpressionExecute: + @staticmethod + def _exec(expression: str, **kwargs) -> object: + values = OrderedDict(kwargs) + return MathExpressionNode.execute(expression, values) + + def test_addition(self): + result = self._exec("a + b", a=3, b=4) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_subtraction(self): + result = self._exec("a - b", a=10, b=3) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_multiplication(self): + result = self._exec("a * b", a=3, b=5) + assert result[0] == 15.0 + assert result[1] == 15 + + def test_division(self): + result = self._exec("a / b", a=10, b=4) + assert result[0] == 2.5 + assert result[1] == 2 + + def test_single_input(self): + result = self._exec("a * 2", a=5) + assert result[0] == 10.0 + assert result[1] == 10 + + def test_three_inputs(self): + result = self._exec("a + b + c", a=1, b=2, c=3) + assert result[0] == 6.0 + assert result[1] == 6 + + def test_float_inputs(self): + result = self._exec("a + b", a=1.5, b=2.5) + assert result[0] == 4.0 + assert result[1] == 4 + + def test_mixed_int_float_inputs(self): + result = self._exec("a * b", a=1024, b=1.5) + assert result[0] == 1536.0 + assert result[1] == 1536 + + def test_mixed_resolution_scale(self): + result = self._exec("a * b", a=512, b=0.75) + assert result[0] == 384.0 + assert result[1] == 384 + + def test_sum_values_array(self): + result = self._exec("sum(values)", a=1, b=2, c=3) + assert result[0] == 6.0 + + def test_sum_variadic(self): + result = self._exec("sum(a, b, c)", a=1, b=2, c=3) + assert result[0] == 6.0 + + def test_min_values(self): + result = self._exec("min(values)", a=5, b=2, c=8) + assert result[0] == 2.0 + + def test_max_values(self): + result = self._exec("max(values)", a=5, b=2, c=8) + assert result[0] == 8.0 + + def test_abs_function(self): + result = self._exec("abs(a)", a=-7) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_sqrt(self): + result = self._exec("sqrt(a)", a=16) + assert result[0] == 4.0 + assert result[1] == 4 + + def test_ceil(self): + result = self._exec("ceil(a)", a=2.3) + assert result[0] == 3.0 + assert result[1] == 3 + + def test_floor(self): + result = self._exec("floor(a)", a=2.7) + assert result[0] == 2.0 + assert result[1] == 2 + + def test_sin(self): + result = self._exec("sin(a)", a=0) + assert result[0] == 0.0 + + def test_log10(self): + result = self._exec("log10(a)", a=100) + assert result[0] == 2.0 + assert result[1] == 2 + + def test_float_output_type(self): + result = self._exec("a + b", a=1, b=2) + assert isinstance(result[0], float) + + def test_int_output_type(self): + result = self._exec("a + b", a=1, b=2) + assert isinstance(result[1], int) + + def test_non_numeric_result_raises(self): + with pytest.raises(ValueError, match="must evaluate to a numeric result"): + self._exec("'hello'", a=42) + + def test_undefined_function_raises(self): + with pytest.raises(Exception, match="not defined"): + self._exec("str(a)", a=42) + + def test_boolean_result_raises(self): + with pytest.raises(ValueError, match="got bool"): + self._exec("a > b", a=5, b=3) + + def test_empty_expression_raises(self): + with pytest.raises(ValueError, match="Expression cannot be empty"): + self._exec("", a=1) + + def test_whitespace_only_expression_raises(self): + with pytest.raises(ValueError, match="Expression cannot be empty"): + self._exec(" ", a=1) + + # --- Missing function coverage (round, pow, log, log2, cos, tan) --- + + def test_round(self): + result = self._exec("round(a)", a=2.7) + assert result[0] == 3.0 + assert result[1] == 3 + + def test_round_with_ndigits(self): + result = self._exec("round(a, 2)", a=3.14159) + assert result[0] == pytest.approx(3.14) + + def test_pow(self): + result = self._exec("pow(a, b)", a=2, b=10) + assert result[0] == 1024.0 + assert result[1] == 1024 + + def test_log(self): + result = self._exec("log(a)", a=math.e) + assert result[0] == pytest.approx(1.0) + + def test_log2(self): + result = self._exec("log2(a)", a=8) + assert result[0] == pytest.approx(3.0) + + def test_cos(self): + result = self._exec("cos(a)", a=0) + assert result[0] == 1.0 + + def test_tan(self): + result = self._exec("tan(a)", a=0) + assert result[0] == 0.0 + + # --- int/float converter functions --- + + def test_int_converter(self): + result = self._exec("int(a / b)", a=7, b=2) + assert result[1] == 3 + + def test_float_converter(self): + result = self._exec("float(a)", a=5) + assert result[0] == 5.0 + + # --- Error path tests --- + + def test_division_by_zero_raises(self): + with pytest.raises(ZeroDivisionError): + self._exec("a / b", a=1, b=0) + + def test_sqrt_negative_raises(self): + with pytest.raises(ValueError, match="math domain error"): + self._exec("sqrt(a)", a=-1) + + def test_overflow_inf_raises(self): + with pytest.raises(ValueError, match="non-finite result"): + self._exec("a * b", a=1e308, b=10) + + def test_pow_huge_exponent_raises(self): + with pytest.raises(ValueError, match="Exponent .* exceeds maximum"): + self._exec("pow(a, b)", a=10, b=10000000) From 3b93d5d571cb3e018da65f822cd11b60202b11c2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 6 Mar 2026 11:04:48 +0200 Subject: [PATCH 49/75] feat(api-nodes): add TencentSmartTopology node (#12741) * feat(api-nodes): add TencentSmartTopology node * feat(api-nodes): enable TencentModelTo3DUV node * chore(Tencent endpoints): add "wait" to queued statuses --- comfy_api_nodes/apis/hunyuan3d.py | 16 ++++- comfy_api_nodes/nodes_hunyuan3d.py | 109 ++++++++++++++++++++++++++--- comfy_api_nodes/util/client.py | 2 +- 3 files changed, 114 insertions(+), 13 deletions(-) diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py index e84eba31e..dad9bc2fa 100644 --- a/comfy_api_nodes/apis/hunyuan3d.py +++ b/comfy_api_nodes/apis/hunyuan3d.py @@ -66,13 +66,17 @@ class To3DProTaskQueryRequest(BaseModel): JobId: str = Field(...) -class To3DUVFileInput(BaseModel): +class TaskFile3DInput(BaseModel): Type: str = Field(..., description="File type: GLB, OBJ, or FBX") Url: str = Field(...) class To3DUVTaskRequest(BaseModel): - File: To3DUVFileInput = Field(...) + File: TaskFile3DInput = Field(...) + + +class To3DPartTaskRequest(BaseModel): + File: TaskFile3DInput = Field(...) class TextureEditImageInfo(BaseModel): @@ -80,7 +84,13 @@ class TextureEditImageInfo(BaseModel): class TextureEditTaskRequest(BaseModel): - File3D: To3DUVFileInput = Field(...) + File3D: TaskFile3DInput = Field(...) Image: TextureEditImageInfo | None = Field(None) Prompt: str | None = Field(None) EnablePBR: bool | None = Field(None) + + +class SmartTopologyRequest(BaseModel): + File3D: TaskFile3DInput = Field(...) + PolygonType: str | None = Field(...) + FaceLevel: str | None = Field(...) diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index d1d9578ec..bd8bde997 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -5,18 +5,19 @@ from comfy_api_nodes.apis.hunyuan3d import ( Hunyuan3DViewImage, InputGenerateType, ResultFile3D, + SmartTopologyRequest, + TaskFile3DInput, TextureEditTaskRequest, + To3DPartTaskRequest, To3DProTaskCreateResponse, To3DProTaskQueryRequest, To3DProTaskRequest, To3DProTaskResultResponse, - To3DUVFileInput, To3DUVTaskRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_file_3d, - download_url_to_image_tensor, downscale_image_tensor_by_max_side, poll_op, sync_op, @@ -344,7 +345,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode): outputs=[ IO.File3DOBJ.Output(display_name="OBJ"), IO.File3DFBX.Output(display_name="FBX"), - IO.Image.Output(), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -375,7 +375,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode): ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"), response_model=To3DProTaskCreateResponse, data=To3DUVTaskRequest( - File=To3DUVFileInput( + File=TaskFile3DInput( Type=file_format.upper(), Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format), ) @@ -394,7 +394,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode): return IO.NodeOutput( await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), - await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url), ) @@ -463,7 +462,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode): ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"), response_model=To3DProTaskCreateResponse, data=TextureEditTaskRequest( - File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url), + File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), Prompt=prompt, EnablePBR=True, ), @@ -538,8 +537,8 @@ class Tencent3DPartNode(IO.ComfyNode): cls, ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"), response_model=To3DProTaskCreateResponse, - data=To3DUVTaskRequest( - File=To3DUVFileInput(Type=file_format.upper(), Url=model_url), + data=To3DPartTaskRequest( + File=TaskFile3DInput(Type=file_format.upper(), Url=model_url), ), is_rate_limited=_is_tencent_rate_limited, ) @@ -557,15 +556,107 @@ class Tencent3DPartNode(IO.ComfyNode): ) +class TencentSmartTopologyNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentSmartTopologyNode", + display_name="Hunyuan3D: Smart Topology", + category="api node/3d/Tencent", + description="Perform smart retopology on a 3D model. " + "Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.", + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DAny], + tooltip="Input 3D model (GLB or OBJ)", + ), + IO.Combo.Input( + "polygon_type", + options=["triangle", "quadrilateral"], + tooltip="Surface composition type.", + ), + IO.Combo.Input( + "face_level", + options=["medium", "high", "low"], + tooltip="Polygon reduction level.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.File3DOBJ.Output(display_name="OBJ"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge(expr='{"type":"usd","usd":1.0}'), + ) + + SUPPORTED_FORMATS = {"glb", "obj"} + + @classmethod + async def execute( + cls, + model_3d: Types.File3D, + polygon_type: str, + face_level: str, + seed: int, + ) -> IO.NodeOutput: + _ = seed + file_format = model_3d.format.lower() + if file_format not in cls.SUPPORTED_FORMATS: + raise ValueError( + f"Unsupported file format: '{file_format}'. " f"Supported: {', '.join(sorted(cls.SUPPORTED_FORMATS))}." + ) + model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology", method="POST"), + response_model=To3DProTaskCreateResponse, + data=SmartTopologyRequest( + File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), + PolygonType=polygon_type, + FaceLevel=face_level, + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed: [{response.Error.Code}] {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + return IO.NodeOutput( + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), + ) + + class TencentHunyuan3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ TencentTextToModelNode, TencentImageToModelNode, - # TencentModelTo3DUVNode, + TencentModelTo3DUVNode, # Tencent3DTextureEditNode, Tencent3DPartNode, + TencentSmartTopologyNode, ] diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 94886af7b..79ffb77c1 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -83,7 +83,7 @@ class _PollUIState: _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] -QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"] async def sync_op( From 34e55f006156801a6b5988d046d9041cb681f12d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 6 Mar 2026 19:54:27 +0200 Subject: [PATCH 50/75] feat(api-nodes): add Gemini 3.1 Flash Lite model to LLM node (#12803) --- comfy_api_nodes/nodes_gemini.py | 47 +++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index d83d2fc15..8225ea67e 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -72,18 +72,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( ) -class GeminiModel(str, Enum): - """ - Gemini Model Names allowed by comfy-api - """ - - gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06" - gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17" - gemini_2_5_pro = "gemini-2.5-pro" - gemini_2_5_flash = "gemini-2.5-flash" - gemini_3_0_pro = "gemini-3-pro-preview" - - class GeminiImageModel(str, Enum): """ Gemini Image Model Names allowed by comfy-api @@ -237,10 +225,14 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N input_tokens_price = 0.30 output_text_tokens_price = 2.50 output_image_tokens_price = 30.0 - elif response.modelVersion == "gemini-3-pro-preview": + elif response.modelVersion in ("gemini-3-pro-preview", "gemini-3.1-pro-preview"): input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 0.0 + elif response.modelVersion == "gemini-3.1-flash-lite-preview": + input_tokens_price = 0.25 + output_text_tokens_price = 1.50 + output_image_tokens_price = 0.0 elif response.modelVersion == "gemini-3-pro-image-preview": input_tokens_price = 2 output_text_tokens_price = 12.0 @@ -292,8 +284,16 @@ class GeminiNode(IO.ComfyNode): ), IO.Combo.Input( "model", - options=GeminiModel, - default=GeminiModel.gemini_2_5_pro, + options=[ + "gemini-2.5-pro-preview-05-06", + "gemini-2.5-flash-preview-04-17", + "gemini-2.5-pro", + "gemini-2.5-flash", + "gemini-3-pro-preview", + "gemini-3-1-pro", + "gemini-3-1-flash-lite", + ], + default="gemini-3-1-pro", tooltip="The Gemini model to use for generating responses.", ), IO.Int.Input( @@ -363,11 +363,16 @@ class GeminiNode(IO.ComfyNode): "usd": [0.00125, 0.01], "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } } - : $contains($m, "gemini-3-pro-preview") ? { + : ($contains($m, "gemini-3-pro-preview") or $contains($m, "gemini-3-1-pro")) ? { "type": "list_usd", "usd": [0.002, 0.012], "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } } + : $contains($m, "gemini-3-1-flash-lite") ? { + "type": "list_usd", + "usd": [0.00025, 0.0015], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } : {"type":"text", "text":"Token-based"} ) """, @@ -436,12 +441,14 @@ class GeminiNode(IO.ComfyNode): files: list[GeminiPart] | None = None, system_prompt: str = "", ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=False) + if model == "gemini-3-pro-preview": + model = "gemini-3.1-pro-preview" # model "gemini-3-pro-preview" will be soon deprecated by Google + elif model == "gemini-3-1-pro": + model = "gemini-3.1-pro-preview" + elif model == "gemini-3-1-flash-lite": + model = "gemini-3.1-flash-lite-preview" - # Create parts list with text prompt as the first part parts: list[GeminiPart] = [GeminiPart(text=prompt)] - - # Add other modal parts if images is not None: parts.extend(await create_image_parts(cls, images)) if audio is not None: From f466b066017b9ebe5df67decfcbd09f78c5c66fa Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:20:07 -0800 Subject: [PATCH 51/75] Fix fp16 audio encoder models (#12811) * mp: respect model_defined_dtypes in default caster This is needed for parametrizations when the dtype changes between sd and model. * audio_encoders: archive model dtypes Archive model dtypes to stop the state dict load override the dtypes defined by the core for compute etc. --- comfy/audio_encoders/audio_encoders.py | 1 + comfy/model_patcher.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 16998af94..0de7584b0 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -27,6 +27,7 @@ class AudioEncoderModel(): self.model.eval() self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 + comfy.model_management.archive_model_dtypes(self.model) def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7e5ad7aa4..745384271 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -715,8 +715,8 @@ class ModelPatcher: default = True # default random weights in non leaf modules break if default and default_device is not None: - for param in params.values(): - param.data = param.data.to(device=default_device) + for param_name, param in params.items(): + param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None)) if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0): module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem From d69d30819b91aa020d0bb888df2a5b917f83bb7e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 6 Mar 2026 16:11:16 -0800 Subject: [PATCH 52/75] Don't run TE on cpu when dynamic vram enabled. (#12815) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index ee28ea107..39b4aa483 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -939,7 +939,7 @@ def text_encoder_offload_device(): def text_encoder_device(): if args.gpu_only: return get_torch_device() - elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: + elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled: if should_use_fp16(prioritize_performance=False): return get_torch_device() else: From afc00f00553885eeb96ded329878fe732f6b9f7a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 6 Mar 2026 17:10:53 -0800 Subject: [PATCH 53/75] Fix requirements version. (#12817) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7bf12247c..26e2ecdec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,7 +24,7 @@ av>=14.2.0 comfy-kitchen>=0.2.7 comfy-aimdo>=0.2.7 requests -simpleeval>=1.0 +simpleeval>=1.0.0 #non essential dependencies: kornia>=0.7.1 From 6ac8152fc80734b084d12865460e5e9a5d9a4e1b Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sat, 7 Mar 2026 15:54:09 +0800 Subject: [PATCH 54/75] chore: update workflow templates to v0.9.11 (#12821) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 26e2ecdec..dc9a9ded0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.10 +comfyui-workflow-templates==0.9.11 comfyui-embedded-docs==0.4.3 torch torchsde From bcf1a1fab1e9efe0d4999ea14e9c0318409e0000 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 7 Mar 2026 09:38:08 -0800 Subject: [PATCH 55/75] mm: reset_cast_buffers: sync compute stream before free (#12822) Sync the compute stream before freeing the cast buffers. This can cause use after free issues when the cast stream frees the buffer while the compute stream is behind enough to still needs a casted weight. --- comfy/model_management.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 39b4aa483..07bc8ad67 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1148,6 +1148,7 @@ def reset_cast_buffers(): LARGEST_CASTED_WEIGHT = (None, 0) for offload_stream in STREAM_CAST_BUFFERS: offload_stream.synchronize() + synchronize() STREAM_CAST_BUFFERS.clear() soft_empty_cache() From a7a6335be538f55faa2abf7404c9b8e970847d1f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 7 Mar 2026 16:52:39 -0500 Subject: [PATCH 56/75] ComfyUI v0.16.4 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 5da21150b..2723d02e7 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.3" +__version__ = "0.16.4" diff --git a/pyproject.toml b/pyproject.toml index 6a83c5c63..753b219b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.3" +version = "0.16.4" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 29b24cb5177e9d5aa5b3d2e5869999efb4d538c7 Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Sat, 7 Mar 2026 17:37:25 -0800 Subject: [PATCH 57/75] refactor(assets): modular architecture + async two-phase scanner & background seeder (#12621) --- .../0002_merge_to_asset_references.py | 267 +++++ app/assets/api/routes.py | 810 ++++++++----- app/assets/api/schemas_in.py | 79 +- app/assets/api/schemas_out.py | 6 +- app/assets/api/upload.py | 171 +++ app/assets/database/bulk_ops.py | 204 ---- app/assets/database/models.py | 210 ++-- app/assets/database/queries.py | 976 ---------------- app/assets/database/queries/__init__.py | 121 ++ app/assets/database/queries/asset.py | 140 +++ .../database/queries/asset_reference.py | 1033 +++++++++++++++++ app/assets/database/queries/common.py | 54 + app/assets/database/queries/tags.py | 356 ++++++ app/assets/database/tags.py | 62 - app/assets/hashing.py | 75 -- app/assets/helpers.py | 319 +---- app/assets/manager.py | 516 -------- app/assets/scanner.py | 768 ++++++++---- app/assets/seeder.py | 794 +++++++++++++ app/assets/services/__init__.py | 87 ++ app/assets/services/asset_management.py | 309 +++++ app/assets/services/bulk_ingest.py | 280 +++++ app/assets/services/file_utils.py | 70 ++ app/assets/services/hashing.py | 95 ++ app/assets/services/ingest.py | 375 ++++++ app/assets/services/metadata_extract.py | 327 ++++++ app/assets/services/path_utils.py | 167 +++ app/assets/services/schemas.py | 109 ++ app/assets/services/tagging.py | 75 ++ app/database/db.py | 81 +- comfy/cli_args.py | 2 +- comfy_api/feature_flags.py | 1 + main.py | 35 +- requirements.txt | 2 + server.py | 19 +- tests-unit/assets_test/conftest.py | 16 +- tests-unit/assets_test/helpers.py | 28 + tests-unit/assets_test/queries/conftest.py | 20 + tests-unit/assets_test/queries/test_asset.py | 144 +++ .../assets_test/queries/test_asset_info.py | 517 +++++++++ .../assets_test/queries/test_cache_state.py | 499 ++++++++ .../assets_test/queries/test_metadata.py | 184 +++ tests-unit/assets_test/queries/test_tags.py | 366 ++++++ tests-unit/assets_test/services/__init__.py | 1 + tests-unit/assets_test/services/conftest.py | 54 + .../services/test_asset_management.py | 268 +++++ .../assets_test/services/test_bulk_ingest.py | 137 +++ .../assets_test/services/test_enrich.py | 207 ++++ .../assets_test/services/test_ingest.py | 229 ++++ .../assets_test/services/test_tagging.py | 197 ++++ .../assets_test/test_assets_missing_sync.py | 2 +- tests-unit/assets_test/test_crud.py | 138 ++- tests-unit/assets_test/test_downloads.py | 4 +- tests-unit/assets_test/test_file_utils.py | 121 ++ tests-unit/assets_test/test_list_filter.py | 40 +- .../assets_test/test_prune_orphaned_assets.py | 2 +- .../assets_test/test_sync_references.py | 482 ++++++++ .../{test_tags.py => test_tags_api.py} | 4 +- tests-unit/assets_test/test_uploads.py | 22 +- tests-unit/requirements.txt | 1 - tests-unit/seeder_test/test_seeder.py | 900 ++++++++++++++ utils/mime_types.py | 37 + 62 files changed, 10737 insertions(+), 2878 deletions(-) create mode 100644 alembic_db/versions/0002_merge_to_asset_references.py create mode 100644 app/assets/api/upload.py delete mode 100644 app/assets/database/bulk_ops.py delete mode 100644 app/assets/database/queries.py create mode 100644 app/assets/database/queries/__init__.py create mode 100644 app/assets/database/queries/asset.py create mode 100644 app/assets/database/queries/asset_reference.py create mode 100644 app/assets/database/queries/common.py create mode 100644 app/assets/database/queries/tags.py delete mode 100644 app/assets/database/tags.py delete mode 100644 app/assets/hashing.py delete mode 100644 app/assets/manager.py create mode 100644 app/assets/seeder.py create mode 100644 app/assets/services/__init__.py create mode 100644 app/assets/services/asset_management.py create mode 100644 app/assets/services/bulk_ingest.py create mode 100644 app/assets/services/file_utils.py create mode 100644 app/assets/services/hashing.py create mode 100644 app/assets/services/ingest.py create mode 100644 app/assets/services/metadata_extract.py create mode 100644 app/assets/services/path_utils.py create mode 100644 app/assets/services/schemas.py create mode 100644 app/assets/services/tagging.py create mode 100644 tests-unit/assets_test/helpers.py create mode 100644 tests-unit/assets_test/queries/conftest.py create mode 100644 tests-unit/assets_test/queries/test_asset.py create mode 100644 tests-unit/assets_test/queries/test_asset_info.py create mode 100644 tests-unit/assets_test/queries/test_cache_state.py create mode 100644 tests-unit/assets_test/queries/test_metadata.py create mode 100644 tests-unit/assets_test/queries/test_tags.py create mode 100644 tests-unit/assets_test/services/__init__.py create mode 100644 tests-unit/assets_test/services/conftest.py create mode 100644 tests-unit/assets_test/services/test_asset_management.py create mode 100644 tests-unit/assets_test/services/test_bulk_ingest.py create mode 100644 tests-unit/assets_test/services/test_enrich.py create mode 100644 tests-unit/assets_test/services/test_ingest.py create mode 100644 tests-unit/assets_test/services/test_tagging.py create mode 100644 tests-unit/assets_test/test_file_utils.py create mode 100644 tests-unit/assets_test/test_sync_references.py rename tests-unit/assets_test/{test_tags.py => test_tags_api.py} (98%) create mode 100644 tests-unit/seeder_test/test_seeder.py create mode 100644 utils/mime_types.py diff --git a/alembic_db/versions/0002_merge_to_asset_references.py b/alembic_db/versions/0002_merge_to_asset_references.py new file mode 100644 index 000000000..1ac1b980c --- /dev/null +++ b/alembic_db/versions/0002_merge_to_asset_references.py @@ -0,0 +1,267 @@ +""" +Merge AssetInfo and AssetCacheState into unified asset_references table. + +This migration drops old tables and creates the new unified schema. +All existing data is discarded. + +Revision ID: 0002_merge_to_asset_references +Revises: 0001_assets +Create Date: 2025-02-11 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0002_merge_to_asset_references" +down_revision = "0001_assets" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Drop old tables (order matters due to FK constraints) + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_table("asset_cache_state") + + op.drop_index("ix_assets_info_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + # Truncate assets table (cascades handled by dropping dependent tables first) + op.execute("DELETE FROM assets") + + # Create asset_references table + op.create_table( + "asset_references", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column( + "asset_id", + sa.String(length=36), + sa.ForeignKey("assets.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("file_path", sa.Text(), nullable=True), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column( + "needs_verify", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column( + "is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false") + ), + sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column( + "preview_id", + sa.String(length=36), + sa.ForeignKey("assets.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=False), nullable=True), + sa.CheckConstraint( + "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" + ), + sa.CheckConstraint( + "enrichment_level >= 0 AND enrichment_level <= 2", + name="ck_ar_enrichment_level_range", + ), + ) + op.create_index( + "uq_asset_references_file_path", "asset_references", ["file_path"], unique=True + ) + op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"]) + op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"]) + op.create_index("ix_asset_references_name", "asset_references", ["name"]) + op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"]) + op.create_index( + "ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"] + ) + op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"]) + op.create_index( + "ix_asset_references_last_access_time", "asset_references", ["last_access_time"] + ) + op.create_index( + "ix_asset_references_owner_name", "asset_references", ["owner_id", "name"] + ) + op.create_index("ix_asset_references_deleted_at", "asset_references", ["deleted_at"]) + + # Create asset_reference_tags table + op.create_table( + "asset_reference_tags", + sa.Column( + "asset_reference_id", + sa.String(length=36), + sa.ForeignKey("asset_references.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "tag_name", + sa.String(length=512), + sa.ForeignKey("tags.name", ondelete="RESTRICT"), + nullable=False, + ), + sa.Column( + "origin", sa.String(length=32), nullable=False, server_default="manual" + ), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint( + "asset_reference_id", "tag_name", name="pk_asset_reference_tags" + ), + ) + op.create_index( + "ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"] + ) + op.create_index( + "ix_asset_reference_tags_asset_reference_id", + "asset_reference_tags", + ["asset_reference_id"], + ) + + # Create asset_reference_meta table + op.create_table( + "asset_reference_meta", + sa.Column( + "asset_reference_id", + sa.String(length=36), + sa.ForeignKey("asset_references.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint( + "asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta" + ), + ) + op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"]) + op.create_index( + "ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"] + ) + op.create_index( + "ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"] + ) + op.create_index( + "ix_asset_reference_meta_key_val_bool", + "asset_reference_meta", + ["key", "val_bool"], + ) + + +def downgrade() -> None: + """Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema. + + NOTE: Data is not recoverable. The upgrade discards all rows from the old + tables and truncates assets. After downgrade the old schema will be empty. + A filesystem rescan will repopulate data once the older code is running. + """ + # Drop new tables (order matters due to FK constraints) + op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta") + op.drop_table("asset_reference_meta") + + op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags") + op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags") + op.drop_table("asset_reference_tags") + + op.drop_index("ix_asset_references_deleted_at", table_name="asset_references") + op.drop_index("ix_asset_references_owner_name", table_name="asset_references") + op.drop_index("ix_asset_references_last_access_time", table_name="asset_references") + op.drop_index("ix_asset_references_created_at", table_name="asset_references") + op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references") + op.drop_index("ix_asset_references_is_missing", table_name="asset_references") + op.drop_index("ix_asset_references_name", table_name="asset_references") + op.drop_index("ix_asset_references_owner_id", table_name="asset_references") + op.drop_index("ix_asset_references_asset_id", table_name="asset_references") + op.drop_index("uq_asset_references_file_path", table_name="asset_references") + op.drop_table("asset_references") + + # Truncate assets (upgrade deleted all rows; downgrade starts fresh too) + op.execute("DELETE FROM assets") + + # Recreate old tables from 0001_assets schema + op.create_table( + "assets_info", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"]) + + op.create_table( + "asset_cache_state", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("file_path", sa.Text(), nullable=False), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) + + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 7676e50b4..40dee9f46 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -1,56 +1,144 @@ +import asyncio +import functools +import json import logging -import uuid -import urllib.parse import os -import contextlib -from aiohttp import web +import urllib.parse +import uuid +from typing import Any +from aiohttp import web from pydantic import ValidationError -import app.assets.manager as manager -from app import user_manager -from app.assets.api import schemas_in -from app.assets.helpers import get_query_dict -from app.assets.scanner import seed_assets - import folder_paths +from app import user_manager +from app.assets.api import schemas_in, schemas_out +from app.assets.api.schemas_in import ( + AssetValidationError, + UploadError, +) +from app.assets.helpers import validate_blake3_hash +from app.assets.api.upload import ( + delete_temp_file_if_exists, + parse_multipart_upload, +) +from app.assets.seeder import ScanInProgressError, asset_seeder +from app.assets.services import ( + DependencyMissingError, + HashMismatchError, + apply_tags, + asset_exists, + create_from_hash, + delete_asset_reference, + get_asset_detail, + list_assets_page, + list_tags, + remove_tags, + resolve_asset_for_download, + update_asset_metadata, + upload_from_temp_path, +) ROUTES = web.RouteTableDef() USER_MANAGER: user_manager.UserManager | None = None +_ASSETS_ENABLED = False + + +def _require_assets_feature_enabled(handler): + @functools.wraps(handler) + async def wrapper(request: web.Request) -> web.Response: + if not _ASSETS_ENABLED: + return _build_error_response( + 503, + "SERVICE_DISABLED", + "Assets system is disabled. Start the server with --enable-assets to use this feature.", + ) + return await handler(request) + + return wrapper + # UUID regex (canonical hyphenated form, case-insensitive) UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" -# Note to any custom node developers reading this code: -# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same. -def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: - global USER_MANAGER - USER_MANAGER = user_manager_instance +def get_query_dict(request: web.Request) -> dict[str, Any]: + """Gets a dictionary of query parameters from the request. + + request.query is a MultiMapping[str], needs to be converted to a dict + to be validated by Pydantic. + """ + query_dict = { + key: request.query.getall(key) + if len(request.query.getall(key)) > 1 + else request.query.get(key) + for key in request.query.keys() + } + return query_dict + + +# Note to any custom node developers reading this code: +# The assets system is not yet fully implemented, +# do not rely on the code in /app/assets remaining the same. + + +def register_assets_routes( + app: web.Application, + user_manager_instance: user_manager.UserManager | None = None, +) -> None: + global USER_MANAGER, _ASSETS_ENABLED + if user_manager_instance is not None: + USER_MANAGER = user_manager_instance + _ASSETS_ENABLED = True app.add_routes(ROUTES) -def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response: - return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) + +def disable_assets_routes() -> None: + """Disable asset routes at runtime (e.g. after DB init failure).""" + global _ASSETS_ENABLED + _ASSETS_ENABLED = False -def _validation_error_response(code: str, ve: ValidationError) -> web.Response: - return _error_response(400, code, "Validation failed.", {"errors": ve.json()}) +def _build_error_response( + status: int, code: str, message: str, details: dict | None = None +) -> web.Response: + return web.json_response( + {"error": {"code": code, "message": message, "details": details or {}}}, + status=status, + ) + + +def _build_validation_error_response(code: str, ve: ValidationError) -> web.Response: + errors = json.loads(ve.json()) + return _build_error_response(400, code, "Validation failed.", {"errors": errors}) + + +def _validate_sort_field(requested: str | None) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" @ROUTES.head("/api/assets/hash/{hash}") +@_require_assets_feature_enabled async def head_asset_by_hash(request: web.Request) -> web.Response: hash_str = request.match_info.get("hash", "").strip().lower() - if not hash_str or ":" not in hash_str: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - algo, digest = hash_str.split(":", 1) - if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - exists = manager.asset_exists(asset_hash=hash_str) + try: + hash_str = validate_blake3_hash(hash_str) + except ValueError: + return _build_error_response( + 400, "INVALID_HASH", "hash must be like 'blake3:'" + ) + exists = asset_exists(hash_str) return web.Response(status=200 if exists else 404) @ROUTES.get("/api/assets") -async def list_assets(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def list_assets_route(request: web.Request) -> web.Response: """ GET request to list assets. """ @@ -58,78 +146,140 @@ async def list_assets(request: web.Request) -> web.Response: try: q = schemas_in.ListAssetsQuery.model_validate(query_dict) except ValidationError as ve: - return _validation_error_response("INVALID_QUERY", ve) + return _build_validation_error_response("INVALID_QUERY", ve) - payload = manager.list_assets( + sort = _validate_sort_field(q.sort) + order_candidate = (q.order or "desc").lower() + order = order_candidate if order_candidate in {"asc", "desc"} else "desc" + + result = list_assets_page( + owner_id=USER_MANAGER.get_request_user_id(request), include_tags=q.include_tags, exclude_tags=q.exclude_tags, name_contains=q.name_contains, metadata_filter=q.metadata_filter, limit=q.limit, offset=q.offset, - sort=q.sort, - order=q.order, - owner_id=USER_MANAGER.get_request_user_id(request), + sort=sort, + order=order, + ) + + summaries = [ + schemas_out.AssetSummary( + id=item.ref.id, + name=item.ref.name, + asset_hash=item.asset.hash if item.asset else None, + size=int(item.asset.size_bytes) if item.asset else None, + mime_type=item.asset.mime_type if item.asset else None, + tags=item.tags, + created_at=item.ref.created_at, + updated_at=item.ref.updated_at, + last_access_time=item.ref.last_access_time, + ) + for item in result.items + ] + + payload = schemas_out.AssetsList( + assets=summaries, + total=result.total, + has_more=(q.offset + len(summaries)) < result.total, ) return web.json_response(payload.model_dump(mode="json", exclude_none=True)) @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") -async def get_asset(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def get_asset_route(request: web.Request) -> web.Response: """ GET request to get an asset's info as JSON. """ - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - result = manager.get_asset( - asset_info_id=asset_info_id, + result = get_asset_detail( + reference_id=reference_id, owner_id=USER_MANAGER.get_request_user_id(request), ) + if not result: + return _build_error_response( + 404, + "ASSET_NOT_FOUND", + f"AssetReference {reference_id} not found", + {"id": reference_id}, + ) + + payload = schemas_out.AssetDetail( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash if result.asset else None, + size=int(result.asset.size_bytes) if result.asset else None, + mime_type=result.asset.mime_type if result.asset else None, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + ) except ValueError as e: - return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(e), {"id": reference_id} + ) except Exception: logging.exception( - "get_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "get_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") +@_require_assets_feature_enabled async def download_asset_content(request: web.Request) -> web.Response: - # question: do we need disposition? could we just stick with one of these? disposition = request.query.get("disposition", "attachment").lower().strip() if disposition not in {"inline", "attachment"}: disposition = "attachment" try: - abs_path, content_type, filename = manager.resolve_asset_content_for_download( - asset_info_id=str(uuid.UUID(request.match_info["id"])), + result = resolve_asset_for_download( + reference_id=str(uuid.UUID(request.match_info["id"])), owner_id=USER_MANAGER.get_request_user_id(request), ) + abs_path = result.abs_path + content_type = result.content_type + filename = result.download_name except ValueError as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve)) + return _build_error_response(404, "ASSET_NOT_FOUND", str(ve)) except NotImplementedError as nie: - return _error_response(501, "BACKEND_UNSUPPORTED", str(nie)) + return _build_error_response(501, "BACKEND_UNSUPPORTED", str(nie)) except FileNotFoundError: - return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") + return _build_error_response( + 404, "FILE_NOT_FOUND", "Underlying file not found on disk." + ) - quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'") - cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' + _DANGEROUS_MIME_TYPES = { + "text/html", "text/html-sandboxed", "application/xhtml+xml", + "text/javascript", "text/css", + } + if content_type in _DANGEROUS_MIME_TYPES: + content_type = "application/octet-stream" + + safe_name = (filename or "").replace("\r", "").replace("\n", "") + encoded = urllib.parse.quote(safe_name) + cd = f"{disposition}; filename*=UTF-8''{encoded}" file_size = os.path.getsize(abs_path) + size_mb = file_size / (1024 * 1024) logging.info( - "download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s", + "download_asset_content: path=%s, size=%d bytes (%.2f MB), type=%s, name=%s", abs_path, file_size, - file_size / (1024 * 1024), + size_mb, content_type, filename, ) - async def file_sender(): + async def stream_file_chunks(): chunk_size = 64 * 1024 with open(abs_path, "rb") as f: while True: @@ -139,26 +289,30 @@ async def download_asset_content(request: web.Request) -> web.Response: yield chunk return web.Response( - body=file_sender(), + body=stream_file_chunks(), content_type=content_type, headers={ "Content-Disposition": cd, "Content-Length": str(file_size), + "X-Content-Type-Options": "nosniff", }, ) @ROUTES.post("/api/assets/from-hash") -async def create_asset_from_hash(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def create_asset_from_hash_route(request: web.Request) -> web.Response: try: payload = await request.json() body = schemas_in.CreateFromHashBody.model_validate(payload) except ValidationError as ve: - return _validation_error_response("INVALID_BODY", ve) + return _build_validation_error_response("INVALID_BODY", ve) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) - result = manager.create_asset_from_hash( + result = create_from_hash( hash_str=body.hash, name=body.name, tags=body.tags, @@ -166,246 +320,209 @@ async def create_asset_from_hash(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), ) if result is None: - return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") - return web.json_response(result.model_dump(mode="json"), status=201) + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" + ) + + payload_out = schemas_out.AssetCreated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash, + size=int(result.asset.size_bytes), + mime_type=result.asset.mime_type, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + created_new=result.created_new, + ) + return web.json_response(payload_out.model_dump(mode="json"), status=201) @ROUTES.post("/api/assets") +@_require_assets_feature_enabled async def upload_asset(request: web.Request) -> web.Response: """Multipart/form-data endpoint for Asset uploads.""" - if not (request.content_type or "").lower().startswith("multipart/"): - return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.") - - reader = await request.multipart() - - file_present = False - file_client_name: str | None = None - tags_raw: list[str] = [] - provided_name: str | None = None - user_metadata_raw: str | None = None - provided_hash: str | None = None - provided_hash_exists: bool | None = None - - file_written = 0 - tmp_path: str | None = None - while True: - field = await reader.next() - if field is None: - break - - fname = getattr(field, "name", "") or "" - - if fname == "hash": - try: - s = ((await field.text()) or "").strip().lower() - except Exception: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - - if s: - if ":" not in s: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - provided_hash = f"{algo}:{digest}" - try: - provided_hash_exists = manager.asset_exists(asset_hash=provided_hash) - except Exception: - provided_hash_exists = None # do not fail the whole request here - - elif fname == "file": - file_present = True - file_client_name = (field.filename or "").strip() - - if provided_hash and provided_hash_exists is True: - # If client supplied a hash that we know exists, drain but do not write to disk - try: - while True: - chunk = await field.read_chunk(8 * 1024 * 1024) - if not chunk: - break - file_written += len(chunk) - except Exception: - return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.") - continue # Do not create temp file; we will create AssetInfo from the existing content - - # Otherwise, store to temp for hashing/ingest - uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") - unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) - os.makedirs(unique_dir, exist_ok=True) - tmp_path = os.path.join(unique_dir, ".upload.part") - - try: - with open(tmp_path, "wb") as f: - while True: - chunk = await field.read_chunk(8 * 1024 * 1024) - if not chunk: - break - f.write(chunk) - file_written += len(chunk) - except Exception: - try: - if os.path.exists(tmp_path or ""): - os.remove(tmp_path) - finally: - return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.") - elif fname == "tags": - tags_raw.append((await field.text()) or "") - elif fname == "name": - provided_name = (await field.text()) or None - elif fname == "user_metadata": - user_metadata_raw = (await field.text()) or None - - # If client did not send file, and we are not doing a from-hash fast path -> error - if not file_present and not (provided_hash and provided_hash_exists): - return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.") - - if file_present and file_written == 0 and not (provided_hash and provided_hash_exists): - # Empty upload is only acceptable if we are fast-pathing from existing hash - try: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - finally: - return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.") - try: - spec = schemas_in.UploadAssetSpec.model_validate({ - "tags": tags_raw, - "name": provided_name, - "user_metadata": user_metadata_raw, - "hash": provided_hash, - }) - except ValidationError as ve: - try: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - finally: - return _validation_error_response("INVALID_BODY", ve) - - # Validate models category against configured folders (consistent with previous behavior) - if spec.tags and spec.tags[0] == "models": - if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - return _error_response( - 400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'" - ) + parsed = await parse_multipart_upload(request, check_hash_exists=asset_exists) + except UploadError as e: + return _build_error_response(e.status, e.code, e.message) owner_id = USER_MANAGER.get_request_user_id(request) - # Fast path: if a valid provided hash exists, create AssetInfo without writing anything - if spec.hash and provided_hash_exists is True: - try: - result = manager.create_asset_from_hash( + try: + spec = schemas_in.UploadAssetSpec.model_validate( + { + "tags": parsed.tags_raw, + "name": parsed.provided_name, + "user_metadata": parsed.user_metadata_raw, + "hash": parsed.provided_hash, + } + ) + except ValidationError as ve: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response( + 400, "INVALID_BODY", f"Validation failed: {ve.json()}" + ) + + if spec.tags and spec.tags[0] == "models": + if ( + len(spec.tags) < 2 + or spec.tags[1] not in folder_paths.folder_names_and_paths + ): + delete_temp_file_if_exists(parsed.tmp_path) + category = spec.tags[1] if len(spec.tags) >= 2 else "" + return _build_error_response( + 400, "INVALID_BODY", f"unknown models category '{category}'" + ) + + try: + # Fast path: hash exists, create AssetReference without writing anything + if spec.hash and parsed.provided_hash_exists is True: + result = create_from_hash( hash_str=spec.hash, name=spec.name or (spec.hash.split(":", 1)[1]), tags=spec.tags, user_metadata=spec.user_metadata or {}, owner_id=owner_id, ) - except Exception: - logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id) - return _error_response(500, "INTERNAL", "Unexpected server error.") + if result is None: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist" + ) + delete_temp_file_if_exists(parsed.tmp_path) + else: + # Otherwise, we must have a temp file path to ingest + if not parsed.tmp_path or not os.path.exists(parsed.tmp_path): + return _build_error_response( + 400, + "MISSING_INPUT", + "Provided hash not found and no file uploaded.", + ) - if result is None: - return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist") - - # Drain temp if we accidentally saved (e.g., hash field came after file) - if tmp_path and os.path.exists(tmp_path): - with contextlib.suppress(Exception): - os.remove(tmp_path) - - status = 200 if (not result.created_new) else 201 - return web.json_response(result.model_dump(mode="json"), status=status) - - # Otherwise, we must have a temp file path to ingest - if not tmp_path or not os.path.exists(tmp_path): - # The only case we reach here without a temp file is: client sent a hash that does not exist and no file - return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.") - - try: - created = manager.upload_asset_from_temp_path( - spec, - temp_path=tmp_path, - client_filename=file_client_name, - owner_id=owner_id, - expected_asset_hash=spec.hash, - ) - status = 201 if created.created_new else 200 - return web.json_response(created.model_dump(mode="json"), status=status) - except ValueError as e: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - msg = str(e) - if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH": - return _error_response( - 400, - "HASH_MISMATCH", - "Uploaded file hash does not match provided hash.", + result = upload_from_temp_path( + temp_path=parsed.tmp_path, + name=spec.name, + tags=spec.tags, + user_metadata=spec.user_metadata or {}, + client_filename=parsed.file_client_name, + owner_id=owner_id, + expected_hash=spec.hash, ) - return _error_response(400, "BAD_REQUEST", "Invalid inputs.") + except AssetValidationError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, e.code, str(e)) + except ValueError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, "BAD_REQUEST", str(e)) + except HashMismatchError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, "HASH_MISMATCH", str(e)) + except DependencyMissingError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(503, "DEPENDENCY_MISSING", e.message) except Exception: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id) - return _error_response(500, "INTERNAL", "Unexpected server error.") + delete_temp_file_if_exists(parsed.tmp_path) + logging.exception("upload_asset failed for owner_id=%s", owner_id) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + + payload = schemas_out.AssetCreated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash, + size=int(result.asset.size_bytes), + mime_type=result.asset.mime_type, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + created_new=result.created_new, + ) + status = 201 if result.created_new else 200 + return web.json_response(payload.model_dump(mode="json"), status=status) @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") -async def update_asset(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) +@_require_assets_feature_enabled +async def update_asset_route(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) try: body = schemas_in.UpdateAssetBody.model_validate(await request.json()) except ValidationError as ve: - return _validation_error_response("INVALID_BODY", ve) + return _build_validation_error_response("INVALID_BODY", ve) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.update_asset( - asset_info_id=asset_info_id, + result = update_asset_metadata( + reference_id=reference_id, name=body.name, user_metadata=body.user_metadata, owner_id=USER_MANAGER.get_request_user_id(request), ) - except (ValueError, PermissionError) as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + payload = schemas_out.AssetUpdated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash if result.asset else None, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + updated_at=result.ref.updated_at, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) + except ValueError as ve: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "update_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "update_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") -async def delete_asset(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) - delete_content = request.query.get("delete_content") - delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"} +@_require_assets_feature_enabled +async def delete_asset_route(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) + delete_content_param = request.query.get("delete_content") + delete_content = ( + False + if delete_content_param is None + else delete_content_param.lower() not in {"0", "false", "no"} + ) try: - deleted = manager.delete_asset_reference( - asset_info_id=asset_info_id, + deleted = delete_asset_reference( + reference_id=reference_id, owner_id=USER_MANAGER.get_request_user_id(request), delete_content_if_orphan=delete_content, ) except Exception: logging.exception( - "delete_asset_reference failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "delete_asset_reference failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") if not deleted: - return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.") + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"AssetReference {reference_id} not found." + ) return web.Response(status=204) @ROUTES.get("/api/tags") +@_require_assets_feature_enabled async def get_tags(request: web.Request) -> web.Response: """ GET request to list all tags based on query parameters. @@ -415,12 +532,14 @@ async def get_tags(request: web.Request) -> web.Response: try: query = schemas_in.TagsListQuery.model_validate(query_map) except ValidationError as e: - return web.json_response( - {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}}, - status=400, + return _build_error_response( + 400, + "INVALID_QUERY", + "Invalid query parameters", + {"errors": json.loads(e.json())}, ) - result = manager.list_tags( + rows, total = list_tags( prefix=query.prefix, limit=query.limit, offset=query.offset, @@ -428,87 +547,212 @@ async def get_tags(request: web.Request) -> web.Response: include_zero=query.include_zero, owner_id=USER_MANAGER.get_request_user_id(request), ) - return web.json_response(result.model_dump(mode="json")) + + tags = [ + schemas_out.TagUsage(name=name, count=count, type=tag_type) + for (name, tag_type, count) in rows + ] + payload = schemas_out.TagsList( + tags=tags, total=total, has_more=(query.offset + len(tags)) < total + ) + return web.json_response(payload.model_dump(mode="json")) @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled async def add_asset_tags(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - payload = await request.json() - data = schemas_in.TagsAdd.model_validate(payload) + json_payload = await request.json() + data = schemas_in.TagsAdd.model_validate(json_payload) except ValidationError as ve: - return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()}) + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags add.", + {"errors": ve.errors()}, + ) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.add_tags_to_asset( - asset_info_id=asset_info_id, + result = apply_tags( + reference_id=reference_id, tags=data.tags, origin="manual", owner_id=USER_MANAGER.get_request_user_id(request), ) - except (ValueError, PermissionError) as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + payload = schemas_out.TagsAdd( + added=result.added, + already_present=result.already_present, + total_tags=result.total_tags, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) + except ValueError as ve: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "add_tags_to_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled async def delete_asset_tags(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - payload = await request.json() - data = schemas_in.TagsRemove.model_validate(payload) + json_payload = await request.json() + data = schemas_in.TagsRemove.model_validate(json_payload) except ValidationError as ve: - return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()}) + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags remove.", + {"errors": ve.errors()}, + ) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.remove_tags_from_asset( - asset_info_id=asset_info_id, + result = remove_tags( + reference_id=reference_id, tags=data.tags, owner_id=USER_MANAGER.get_request_user_id(request), ) + payload = schemas_out.TagsRemove( + removed=result.removed, + not_present=result.not_present, + total_tags=result.total_tags, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) except ValueError as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "remove_tags_from_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.post("/api/assets/seed") -async def seed_assets_endpoint(request: web.Request) -> web.Response: - """Trigger asset seeding for specified roots (models, input, output).""" +@_require_assets_feature_enabled +async def seed_assets(request: web.Request) -> web.Response: + """Trigger asset seeding for specified roots (models, input, output). + + Query params: + wait: If "true", block until scan completes (synchronous behavior for tests) + + Returns: + 202 Accepted if scan started + 409 Conflict if scan already running + 200 OK with final stats if wait=true + """ try: payload = await request.json() roots = payload.get("roots", ["models", "input", "output"]) except Exception: roots = ["models", "input", "output"] - valid_roots = [r for r in roots if r in ("models", "input", "output")] + valid_roots = tuple(r for r in roots if r in ("models", "input", "output")) if not valid_roots: - return _error_response(400, "INVALID_BODY", "No valid roots specified") + return _build_error_response(400, "INVALID_BODY", "No valid roots specified") + wait_param = request.query.get("wait", "").lower() + should_wait = wait_param in ("true", "1", "yes") + + started = asset_seeder.start(roots=valid_roots) + if not started: + return web.json_response({"status": "already_running"}, status=409) + + if should_wait: + await asyncio.to_thread(asset_seeder.wait) + status = asset_seeder.get_status() + return web.json_response( + { + "status": "completed", + "progress": { + "scanned": status.progress.scanned if status.progress else 0, + "total": status.progress.total if status.progress else 0, + "created": status.progress.created if status.progress else 0, + "skipped": status.progress.skipped if status.progress else 0, + }, + "errors": status.errors, + }, + status=200, + ) + + return web.json_response({"status": "started"}, status=202) + + +@ROUTES.get("/api/assets/seed/status") +@_require_assets_feature_enabled +async def get_seed_status(request: web.Request) -> web.Response: + """Get current scan status and progress.""" + status = asset_seeder.get_status() + return web.json_response( + { + "state": status.state.value, + "progress": { + "scanned": status.progress.scanned, + "total": status.progress.total, + "created": status.progress.created, + "skipped": status.progress.skipped, + } + if status.progress + else None, + "errors": status.errors, + }, + status=200, + ) + + +@ROUTES.post("/api/assets/seed/cancel") +@_require_assets_feature_enabled +async def cancel_seed(request: web.Request) -> web.Response: + """Request cancellation of in-progress scan.""" + cancelled = asset_seeder.cancel() + if cancelled: + return web.json_response({"status": "cancelling"}, status=200) + return web.json_response({"status": "idle"}, status=200) + + +@ROUTES.post("/api/assets/prune") +@_require_assets_feature_enabled +async def mark_missing_assets(request: web.Request) -> web.Response: + """Mark assets as missing when outside all known root prefixes. + + This is a non-destructive soft-delete operation. Assets and metadata + are preserved, but references are flagged as missing. They can be + restored if the file reappears in a future scan. + + Returns: + 200 OK with count of marked assets + 409 Conflict if a scan is currently running + """ try: - seed_assets(tuple(valid_roots)) - except Exception: - logging.exception("seed_assets failed for roots=%s", valid_roots) - return _error_response(500, "INTERNAL", "Seed operation failed") - - return web.json_response({"seeded": valid_roots}, status=200) + marked = asset_seeder.mark_missing_outside_prefixes() + except ScanInProgressError: + return web.json_response( + {"status": "scan_running", "marked": 0}, + status=409, + ) + return web.json_response({"status": "completed", "marked": marked}, status=200) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 6707ffb0c..d255c938e 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -1,6 +1,8 @@ import json +from dataclasses import dataclass from typing import Any, Literal +from app.assets.helpers import validate_blake3_hash from pydantic import ( BaseModel, ConfigDict, @@ -10,6 +12,41 @@ from pydantic import ( model_validator, ) + +class UploadError(Exception): + """Error during upload parsing with HTTP status and code.""" + + def __init__(self, status: int, code: str, message: str): + super().__init__(message) + self.status = status + self.code = code + self.message = message + + +class AssetValidationError(Exception): + """Validation error in asset processing (invalid tags, metadata, etc.).""" + + def __init__(self, code: str, message: str): + super().__init__(message) + self.code = code + self.message = message + + +@dataclass +class ParsedUpload: + """Result of parsing a multipart upload request.""" + + file_present: bool + file_written: int + file_client_name: str | None + tmp_path: str | None + tags_raw: list[str] + provided_name: str | None + user_metadata_raw: str | None + provided_hash: str | None + provided_hash_exists: bool | None + + class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) @@ -21,7 +58,9 @@ class ListAssetsQuery(BaseModel): limit: conint(ge=1, le=500) = 20 offset: conint(ge=0) = 0 - sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" + sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = ( + "created_at" + ) order: Literal["asc", "desc"] = "desc" @field_validator("include_tags", "exclude_tags", mode="before") @@ -61,7 +100,7 @@ class UpdateAssetBody(BaseModel): user_metadata: dict[str, Any] | None = None @model_validator(mode="after") - def _at_least_one(self): + def _validate_at_least_one_field(self): if self.name is None and self.user_metadata is None: raise ValueError("Provide at least one of: name, user_metadata.") return self @@ -78,19 +117,11 @@ class CreateFromHashBody(BaseModel): @field_validator("hash") @classmethod def _require_blake3(cls, v): - s = (v or "").strip().lower() - if ":" not in s: - raise ValueError("hash must be 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3": - raise ValueError("only canonical 'blake3:' is accepted here") - if not digest or any(c for c in digest if c not in "0123456789abcdef"): - raise ValueError("hash digest must be lowercase hex") - return s + return validate_blake3_hash(v or "") @field_validator("tags", mode="before") @classmethod - def _tags_norm(cls, v): + def _normalize_tags_field(cls, v): if v is None: return [] if isinstance(v, list): @@ -154,15 +185,16 @@ class TagsRemove(TagsAdd): class UploadAssetSpec(BaseModel): """Upload Asset operation. + - tags: ordered; first is root ('models'|'input'|'output'); - if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths + if root == 'models', second must be a valid category - name: display name - user_metadata: arbitrary JSON object (optional) - - hash: optional canonical 'blake3:' provided by the client for validation / fast-path + - hash: optional canonical 'blake3:' for validation / fast-path - Files created via this endpoint are stored on disk using the **content hash** as the filename stem - and the original extension is preserved when available. + Files are stored using the content hash as filename stem. """ + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) tags: list[str] = Field(..., min_length=1) @@ -175,17 +207,10 @@ class UploadAssetSpec(BaseModel): def _parse_hash(cls, v): if v is None: return None - s = str(v).strip().lower() + s = str(v).strip() if not s: return None - if ":" not in s: - raise ValueError("hash must be 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3": - raise ValueError("only canonical 'blake3:' is accepted here") - if not digest or any(c for c in digest if c not in "0123456789abcdef"): - raise ValueError("hash digest must be lowercase hex") - return f"{algo}:{digest}" + return validate_blake3_hash(s) @field_validator("tags", mode="before") @classmethod @@ -260,5 +285,7 @@ class UploadAssetSpec(BaseModel): raise ValueError("first tag must be one of: models, input, output") if root == "models": if len(self.tags) < 2: - raise ValueError("models uploads require a category tag as the second tag") + raise ValueError( + "models uploads require a category tag as the second tag" + ) return self diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index b6fb3da0c..f36447856 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -19,7 +19,7 @@ class AssetSummary(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("created_at", "updated_at", "last_access_time") - def _ser_dt(self, v: datetime | None, _info): + def _serialize_datetime(self, v: datetime | None, _info): return v.isoformat() if v else None @@ -40,7 +40,7 @@ class AssetUpdated(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("updated_at") - def _ser_updated(self, v: datetime | None, _info): + def _serialize_updated_at(self, v: datetime | None, _info): return v.isoformat() if v else None @@ -59,7 +59,7 @@ class AssetDetail(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("created_at", "last_access_time") - def _ser_dt(self, v: datetime | None, _info): + def _serialize_datetime(self, v: datetime | None, _info): return v.isoformat() if v else None diff --git a/app/assets/api/upload.py b/app/assets/api/upload.py new file mode 100644 index 000000000..721c12f4d --- /dev/null +++ b/app/assets/api/upload.py @@ -0,0 +1,171 @@ +import logging +import os +import uuid +from typing import Callable + +from aiohttp import web + +import folder_paths +from app.assets.api.schemas_in import ParsedUpload, UploadError +from app.assets.helpers import validate_blake3_hash + + +def normalize_and_validate_hash(s: str) -> str: + """Validate and normalize a hash string. + + Returns canonical 'blake3:' or raises UploadError. + """ + try: + return validate_blake3_hash(s) + except ValueError: + raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:'") + + +async def parse_multipart_upload( + request: web.Request, + check_hash_exists: Callable[[str], bool], +) -> ParsedUpload: + """ + Parse a multipart/form-data upload request. + + Args: + request: The aiohttp request + check_hash_exists: Callable(hash_str) -> bool to check if a hash exists + + Returns: + ParsedUpload with parsed fields and temp file path + + Raises: + UploadError: On validation or I/O errors + """ + if not (request.content_type or "").lower().startswith("multipart/"): + raise UploadError( + 415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads." + ) + + reader = await request.multipart() + + file_present = False + file_client_name: str | None = None + tags_raw: list[str] = [] + provided_name: str | None = None + user_metadata_raw: str | None = None + provided_hash: str | None = None + provided_hash_exists: bool | None = None + + file_written = 0 + tmp_path: str | None = None + + while True: + field = await reader.next() + if field is None: + break + + fname = getattr(field, "name", "") or "" + + if fname == "hash": + try: + s = ((await field.text()) or "").strip().lower() + except Exception: + raise UploadError( + 400, "INVALID_HASH", "hash must be like 'blake3:'" + ) + + if s: + provided_hash = normalize_and_validate_hash(s) + try: + provided_hash_exists = check_hash_exists(provided_hash) + except Exception as e: + logging.exception( + "check_hash_exists failed for hash=%s: %s", provided_hash, e + ) + raise UploadError( + 500, + "HASH_CHECK_FAILED", + "Backend error while checking asset hash.", + ) + + elif fname == "file": + file_present = True + file_client_name = (field.filename or "").strip() + + if provided_hash and provided_hash_exists is True: + # Hash exists - drain file but don't write to disk + try: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + file_written += len(chunk) + except Exception: + raise UploadError( + 500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file." + ) + continue + + uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") + unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) + os.makedirs(unique_dir, exist_ok=True) + tmp_path = os.path.join(unique_dir, ".upload.part") + + try: + with open(tmp_path, "wb") as f: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + f.write(chunk) + file_written += len(chunk) + except Exception: + delete_temp_file_if_exists(tmp_path) + raise UploadError( + 500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file." + ) + + elif fname == "tags": + tags_raw.append((await field.text()) or "") + elif fname == "name": + provided_name = (await field.text()) or None + elif fname == "user_metadata": + user_metadata_raw = (await field.text()) or None + + if not file_present and not (provided_hash and provided_hash_exists): + raise UploadError( + 400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'." + ) + + if ( + file_present + and file_written == 0 + and not (provided_hash and provided_hash_exists) + ): + delete_temp_file_if_exists(tmp_path) + raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.") + + return ParsedUpload( + file_present=file_present, + file_written=file_written, + file_client_name=file_client_name, + tmp_path=tmp_path, + tags_raw=tags_raw, + provided_name=provided_name, + user_metadata_raw=user_metadata_raw, + provided_hash=provided_hash, + provided_hash_exists=provided_hash_exists, + ) + + +def delete_temp_file_if_exists(tmp_path: str | None) -> None: + """Safely remove a temp file and its parent directory if empty.""" + if tmp_path: + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + except OSError as e: + logging.debug("Failed to delete temp file %s: %s", tmp_path, e) + try: + parent = os.path.dirname(tmp_path) + if parent and os.path.isdir(parent): + os.rmdir(parent) # only succeeds if empty + except OSError: + pass diff --git a/app/assets/database/bulk_ops.py b/app/assets/database/bulk_ops.py deleted file mode 100644 index c7b75290a..000000000 --- a/app/assets/database/bulk_ops.py +++ /dev/null @@ -1,204 +0,0 @@ -import os -import uuid -import sqlalchemy -from typing import Iterable -from sqlalchemy.orm import Session -from sqlalchemy.dialects import sqlite - -from app.assets.helpers import utcnow -from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta - -MAX_BIND_PARAMS = 800 - -def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]: - if not rows: - return [] - rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row)) - for i in range(0, len(rows), rows_per_stmt): - yield rows[i:i + rows_per_stmt] - -def _iter_chunks(seq, n: int): - for i in range(0, len(seq), n): - yield seq[i:i + n] - -def _rows_per_stmt(cols: int) -> int: - return max(1, MAX_BIND_PARAMS // max(1, cols)) - - -def seed_from_paths_batch( - session: Session, - *, - specs: list[dict], - owner_id: str = "", -) -> dict: - """Each spec is a dict with keys: - - abs_path: str - - size_bytes: int - - mtime_ns: int - - info_name: str - - tags: list[str] - - fname: Optional[str] - """ - if not specs: - return {"inserted_infos": 0, "won_states": 0, "lost_states": 0} - - now = utcnow() - asset_rows: list[dict] = [] - state_rows: list[dict] = [] - path_to_asset: dict[str, str] = {} - asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row - path_list: list[str] = [] - - for sp in specs: - ap = os.path.abspath(sp["abs_path"]) - aid = str(uuid.uuid4()) - iid = str(uuid.uuid4()) - path_list.append(ap) - path_to_asset[ap] = aid - - asset_rows.append( - { - "id": aid, - "hash": None, - "size_bytes": sp["size_bytes"], - "mime_type": None, - "created_at": now, - } - ) - state_rows.append( - { - "asset_id": aid, - "file_path": ap, - "mtime_ns": sp["mtime_ns"], - } - ) - asset_to_info[aid] = { - "id": iid, - "owner_id": owner_id, - "name": sp["info_name"], - "asset_id": aid, - "preview_id": None, - "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None, - "created_at": now, - "updated_at": now, - "last_access_time": now, - "_tags": sp["tags"], - "_filename": sp["fname"], - } - - # insert all seed Assets (hash=NULL) - ins_asset = sqlite.insert(Asset) - for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)): - session.execute(ins_asset, chunk) - - # try to claim AssetCacheState (file_path) - # Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted - ins_state = ( - sqlite.insert(AssetCacheState) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)): - session.execute(ins_state, chunk) - - # Query to find which of our paths won (were actually inserted) - winners_by_path: set[str] = set() - for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS): - result = session.execute( - sqlalchemy.select(AssetCacheState.file_path) - .where(AssetCacheState.file_path.in_(chunk)) - .where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk])) - ) - winners_by_path.update(result.scalars().all()) - - all_paths_set = set(path_list) - losers_by_path = all_paths_set - winners_by_path - lost_assets = [path_to_asset[p] for p in losers_by_path] - if lost_assets: # losers get their Asset removed - for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS): - session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk))) - - if not winners_by_path: - return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} - - # insert AssetInfo only for winners - # Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted - winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] - ins_info = ( - sqlite.insert(AssetInfo) - .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) - ) - for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)): - session.execute(ins_info, chunk) - - # Query to find which info rows were actually inserted (by matching our generated IDs) - all_info_ids = [row["id"] for row in winner_info_rows] - inserted_info_ids: set[str] = set() - for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS): - result = session.execute( - sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk)) - ) - inserted_info_ids.update(result.scalars().all()) - - # build and insert tag + meta rows for the AssetInfo - tag_rows: list[dict] = [] - meta_rows: list[dict] = [] - if inserted_info_ids: - for row in winner_info_rows: - iid = row["id"] - if iid not in inserted_info_ids: - continue - for t in row["_tags"]: - tag_rows.append({ - "asset_info_id": iid, - "tag_name": t, - "origin": "automatic", - "added_at": now, - }) - if row["_filename"]: - meta_rows.append( - { - "asset_info_id": iid, - "key": "filename", - "ordinal": 0, - "val_str": row["_filename"], - "val_num": None, - "val_bool": None, - "val_json": None, - } - ) - - bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS) - return { - "inserted_infos": len(inserted_info_ids), - "won_states": len(winners_by_path), - "lost_states": len(losers_by_path), - } - - -def bulk_insert_tags_and_meta( - session: Session, - *, - tag_rows: list[dict], - meta_rows: list[dict], - max_bind_params: int, -) -> None: - """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING. - - tag_rows keys: asset_info_id, tag_name, origin, added_at - - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json - """ - if tag_rows: - ins_links = ( - sqlite.insert(AssetInfoTag) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params): - session.execute(ins_links, chunk) - if meta_rows: - ins_meta = ( - sqlite.insert(AssetInfoMeta) - .on_conflict_do_nothing( - index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] - ) - ) - for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params): - session.execute(ins_meta, chunk) diff --git a/app/assets/database/models.py b/app/assets/database/models.py index 3cd28f68b..03c1c1707 100644 --- a/app/assets/database/models.py +++ b/app/assets/database/models.py @@ -2,8 +2,8 @@ from __future__ import annotations import uuid from datetime import datetime - from typing import Any + from sqlalchemy import ( JSON, BigInteger, @@ -16,102 +16,102 @@ from sqlalchemy import ( Numeric, String, Text, - UniqueConstraint, ) from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship -from app.assets.helpers import utcnow -from app.database.models import to_dict, Base +from app.assets.helpers import get_utc_now +from app.database.models import Base class Asset(Base): __tablename__ = "assets" - id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) hash: Mapped[str | None] = mapped_column(String(256), nullable=True) size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) mime_type: Mapped[str | None] = mapped_column(String(255)) created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow + DateTime(timezone=False), nullable=False, default=get_utc_now ) - infos: Mapped[list[AssetInfo]] = relationship( - "AssetInfo", + references: Mapped[list[AssetReference]] = relationship( + "AssetReference", back_populates="asset", - primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id), - foreign_keys=lambda: [AssetInfo.asset_id], + primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id), + foreign_keys=lambda: [AssetReference.asset_id], cascade="all,delete-orphan", passive_deletes=True, ) - preview_of: Mapped[list[AssetInfo]] = relationship( - "AssetInfo", + preview_of: Mapped[list[AssetReference]] = relationship( + "AssetReference", back_populates="preview_asset", - primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id), - foreign_keys=lambda: [AssetInfo.preview_id], + primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id), + foreign_keys=lambda: [AssetReference.preview_id], viewonly=True, ) - cache_states: Mapped[list[AssetCacheState]] = relationship( - back_populates="asset", - cascade="all, delete-orphan", - passive_deletes=True, - ) - __table_args__ = ( Index("uq_assets_hash", "hash", unique=True), Index("ix_assets_mime_type", "mime_type"), CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - return to_dict(self, include_none=include_none) - def __repr__(self) -> str: return f"" -class AssetCacheState(Base): - __tablename__ = "asset_cache_state" +class AssetReference(Base): + """Unified model combining file cache state and user-facing metadata. - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) - file_path: Mapped[str] = mapped_column(Text, nullable=False) - mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) - needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + Each row represents either: + - A filesystem reference (file_path is set) with cache state + - An API-created reference (file_path is NULL) without cache state + """ - asset: Mapped[Asset] = relationship(back_populates="cache_states") + __tablename__ = "asset_references" - __table_args__ = ( - Index("ix_asset_cache_state_file_path", "file_path"), - Index("ix_asset_cache_state_asset_id", "asset_id"), - CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), - UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) + asset_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - return to_dict(self, include_none=include_none) + # Cache state fields (from former AssetCacheState) + file_path: Mapped[str | None] = mapped_column(Text, nullable=True) + mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - def __repr__(self) -> str: - return f"" - - -class AssetInfo(Base): - __tablename__ = "assets_info" - - id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + # Info fields (from former AssetInfo) owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") name: Mapped[str] = mapped_column(String(512), nullable=False) - asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) - preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) - user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) - last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + preview_id: Mapped[str | None] = mapped_column( + String(36), ForeignKey("assets.id", ondelete="SET NULL") + ) + user_metadata: Mapped[dict[str, Any] | None] = mapped_column( + JSON(none_as_null=True) + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + last_access_time: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + deleted_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=False), nullable=True, default=None + ) asset: Mapped[Asset] = relationship( "Asset", - back_populates="infos", + back_populates="references", foreign_keys=[asset_id], lazy="selectin", ) @@ -121,51 +121,59 @@ class AssetInfo(Base): foreign_keys=[preview_id], ) - metadata_entries: Mapped[list[AssetInfoMeta]] = relationship( - back_populates="asset_info", + metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship( + back_populates="asset_reference", cascade="all,delete-orphan", passive_deletes=True, ) - tag_links: Mapped[list[AssetInfoTag]] = relationship( - back_populates="asset_info", + tag_links: Mapped[list[AssetReferenceTag]] = relationship( + back_populates="asset_reference", cascade="all,delete-orphan", passive_deletes=True, - overlaps="tags,asset_infos", + overlaps="tags,asset_references", ) tags: Mapped[list[Tag]] = relationship( - secondary="asset_info_tags", - back_populates="asset_infos", + secondary="asset_reference_tags", + back_populates="asset_references", lazy="selectin", viewonly=True, - overlaps="tag_links,asset_info_links,asset_infos,tag", + overlaps="tag_links,asset_reference_links,asset_references,tag", ) __table_args__ = ( - UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), - Index("ix_assets_info_owner_name", "owner_id", "name"), - Index("ix_assets_info_owner_id", "owner_id"), - Index("ix_assets_info_asset_id", "asset_id"), - Index("ix_assets_info_name", "name"), - Index("ix_assets_info_created_at", "created_at"), - Index("ix_assets_info_last_access_time", "last_access_time"), + Index("uq_asset_references_file_path", "file_path", unique=True), + Index("ix_asset_references_asset_id", "asset_id"), + Index("ix_asset_references_owner_id", "owner_id"), + Index("ix_asset_references_name", "name"), + Index("ix_asset_references_is_missing", "is_missing"), + Index("ix_asset_references_enrichment_level", "enrichment_level"), + Index("ix_asset_references_created_at", "created_at"), + Index("ix_asset_references_last_access_time", "last_access_time"), + Index("ix_asset_references_deleted_at", "deleted_at"), + Index("ix_asset_references_owner_name", "owner_id", "name"), + CheckConstraint( + "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" + ), + CheckConstraint( + "enrichment_level >= 0 AND enrichment_level <= 2", + name="ck_ar_enrichment_level_range", + ), ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - data = to_dict(self, include_none=include_none) - data["tags"] = [t.name for t in self.tags] - return data - def __repr__(self) -> str: - return f"" + path_part = f" path={self.file_path!r}" if self.file_path else "" + return f"" -class AssetInfoMeta(Base): - __tablename__ = "asset_info_meta" +class AssetReferenceMeta(Base): + __tablename__ = "asset_reference_meta" - asset_info_id: Mapped[str] = mapped_column( - String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + asset_reference_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("asset_references.id", ondelete="CASCADE"), + primary_key=True, ) key: Mapped[str] = mapped_column(String(256), primary_key=True) ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) @@ -175,36 +183,40 @@ class AssetInfoMeta(Base): val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True) val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True) - asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries") + asset_reference: Mapped[AssetReference] = relationship( + back_populates="metadata_entries" + ) __table_args__ = ( - Index("ix_asset_info_meta_key", "key"), - Index("ix_asset_info_meta_key_val_str", "key", "val_str"), - Index("ix_asset_info_meta_key_val_num", "key", "val_num"), - Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"), + Index("ix_asset_reference_meta_key", "key"), + Index("ix_asset_reference_meta_key_val_str", "key", "val_str"), + Index("ix_asset_reference_meta_key_val_num", "key", "val_num"), + Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"), ) -class AssetInfoTag(Base): - __tablename__ = "asset_info_tags" +class AssetReferenceTag(Base): + __tablename__ = "asset_reference_tags" - asset_info_id: Mapped[str] = mapped_column( - String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + asset_reference_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("asset_references.id", ondelete="CASCADE"), + primary_key=True, ) tag_name: Mapped[str] = mapped_column( String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True ) origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") added_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow + DateTime(timezone=False), nullable=False, default=get_utc_now ) - asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links") - tag: Mapped[Tag] = relationship(back_populates="asset_info_links") + asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links") + tag: Mapped[Tag] = relationship(back_populates="asset_reference_links") __table_args__ = ( - Index("ix_asset_info_tags_tag_name", "tag_name"), - Index("ix_asset_info_tags_asset_info_id", "asset_info_id"), + Index("ix_asset_reference_tags_tag_name", "tag_name"), + Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"), ) @@ -214,20 +226,18 @@ class Tag(Base): name: Mapped[str] = mapped_column(String(512), primary_key=True) tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") - asset_info_links: Mapped[list[AssetInfoTag]] = relationship( + asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship( back_populates="tag", - overlaps="asset_infos,tags", + overlaps="asset_references,tags", ) - asset_infos: Mapped[list[AssetInfo]] = relationship( - secondary="asset_info_tags", + asset_references: Mapped[list[AssetReference]] = relationship( + secondary="asset_reference_tags", back_populates="tags", viewonly=True, - overlaps="asset_info_links,tag_links,tags,asset_info", + overlaps="asset_reference_links,tag_links,tags,asset_reference", ) - __table_args__ = ( - Index("ix_tags_tag_type", "tag_type"), - ) + __table_args__ = (Index("ix_tags_tag_type", "tag_type"),) def __repr__(self) -> str: return f"" diff --git a/app/assets/database/queries.py b/app/assets/database/queries.py deleted file mode 100644 index d6b33ec7b..000000000 --- a/app/assets/database/queries.py +++ /dev/null @@ -1,976 +0,0 @@ -import os -import logging -import sqlalchemy as sa -from collections import defaultdict -from datetime import datetime -from typing import Iterable, Any -from sqlalchemy import select, delete, exists, func -from sqlalchemy.dialects import sqlite -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session, contains_eager, noload -from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag -from app.assets.helpers import ( - compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow -) -from typing import Sequence - - -def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: - """Build owner visibility predicate for reads. Owner-less rows are visible to everyone.""" - owner_id = (owner_id or "").strip() - if owner_id == "": - return AssetInfo.owner_id == "" - return AssetInfo.owner_id.in_(["", owner_id]) - - -def pick_best_live_path(states: Sequence[AssetCacheState]) -> str: - """ - Return the best on-disk path among cache states: - 1) Prefer a path that exists with needs_verify == False (already verified). - 2) Otherwise, pick the first path that exists. - 3) Otherwise return empty string. - """ - alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] - if not alive: - return "" - for s in alive: - if not getattr(s, "needs_verify", False): - return s.file_path - return alive[0].file_path - - -def apply_tag_filters( - stmt: sa.sql.Select, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, -) -> sa.sql.Select: - """include_tags: every tag must be present; exclude_tags: none may be present.""" - include_tags = normalize_tags(include_tags) - exclude_tags = normalize_tags(exclude_tags) - - if include_tags: - for tag_name in include_tags: - stmt = stmt.where( - exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name == tag_name) - ) - ) - - if exclude_tags: - stmt = stmt.where( - ~exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name.in_(exclude_tags)) - ) - ) - return stmt - - -def apply_metadata_filter( - stmt: sa.sql.Select, - metadata_filter: dict | None = None, -) -> sa.sql.Select: - """Apply filters using asset_info_meta projection table.""" - if not metadata_filter: - return stmt - - def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - return sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - *preds, - ) - - def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: - if value is None: - no_row_for_key = sa.not_( - sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - ) - ) - null_row = _exists_for_pred( - key, - AssetInfoMeta.val_json.is_(None), - AssetInfoMeta.val_str.is_(None), - AssetInfoMeta.val_num.is_(None), - AssetInfoMeta.val_bool.is_(None), - ) - return sa.or_(no_row_for_key, null_row) - - if isinstance(value, bool): - return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) - if isinstance(value, (int, float)): - from decimal import Decimal - num = value if isinstance(value, Decimal) else Decimal(str(value)) - return _exists_for_pred(key, AssetInfoMeta.val_num == num) - if isinstance(value, str): - return _exists_for_pred(key, AssetInfoMeta.val_str == value) - return _exists_for_pred(key, AssetInfoMeta.val_json == value) - - for k, v in metadata_filter.items(): - if isinstance(v, list): - ors = [_exists_clause_for_value(k, elem) for elem in v] - if ors: - stmt = stmt.where(sa.or_(*ors)) - else: - stmt = stmt.where(_exists_clause_for_value(k, v)) - return stmt - - -def asset_exists_by_hash( - session: Session, - *, - asset_hash: str, -) -> bool: - """ - Check if an asset with a given hash exists in database. - """ - row = ( - session.execute( - select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) - ) - ).first() - return row is not None - - -def asset_info_exists_for_asset_id( - session: Session, - *, - asset_id: str, -) -> bool: - q = ( - select(sa.literal(True)) - .select_from(AssetInfo) - .where(AssetInfo.asset_id == asset_id) - .limit(1) - ) - return (session.execute(q)).first() is not None - - -def get_asset_by_hash( - session: Session, - *, - asset_hash: str, -) -> Asset | None: - return ( - session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() - - -def get_asset_info_by_id( - session: Session, - *, - asset_info_id: str, -) -> AssetInfo | None: - return session.get(AssetInfo, asset_info_id) - - -def list_asset_infos_page( - session: Session, - owner_id: str = "", - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, - name_contains: str | None = None, - metadata_filter: dict | None = None, - limit: int = 20, - offset: int = 0, - sort: str = "created_at", - order: str = "desc", -) -> tuple[list[AssetInfo], dict[str, list[str]], int]: - base = ( - select(AssetInfo) - .join(Asset, Asset.id == AssetInfo.asset_id) - .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) - .where(visible_owner_clause(owner_id)) - ) - - if name_contains: - escaped, esc = escape_like_prefix(name_contains) - base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) - - base = apply_tag_filters(base, include_tags, exclude_tags) - base = apply_metadata_filter(base, metadata_filter) - - sort = (sort or "created_at").lower() - order = (order or "desc").lower() - sort_map = { - "name": AssetInfo.name, - "created_at": AssetInfo.created_at, - "updated_at": AssetInfo.updated_at, - "last_access_time": AssetInfo.last_access_time, - "size": Asset.size_bytes, - } - sort_col = sort_map.get(sort, AssetInfo.created_at) - sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() - - base = base.order_by(sort_exp).limit(limit).offset(offset) - - count_stmt = ( - select(sa.func.count()) - .select_from(AssetInfo) - .join(Asset, Asset.id == AssetInfo.asset_id) - .where(visible_owner_clause(owner_id)) - ) - if name_contains: - escaped, esc = escape_like_prefix(name_contains) - count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) - count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) - count_stmt = apply_metadata_filter(count_stmt, metadata_filter) - - total = int((session.execute(count_stmt)).scalar_one() or 0) - - infos = (session.execute(base)).unique().scalars().all() - - id_list: list[str] = [i.id for i in infos] - tag_map: dict[str, list[str]] = defaultdict(list) - if id_list: - rows = session.execute( - select(AssetInfoTag.asset_info_id, Tag.name) - .join(Tag, Tag.name == AssetInfoTag.tag_name) - .where(AssetInfoTag.asset_info_id.in_(id_list)) - .order_by(AssetInfoTag.added_at) - ) - for aid, tag_name in rows.all(): - tag_map[aid].append(tag_name) - - return infos, tag_map, total - - -def fetch_asset_info_asset_and_tags( - session: Session, - asset_info_id: str, - owner_id: str = "", -) -> tuple[AssetInfo, Asset, list[str]] | None: - stmt = ( - select(AssetInfo, Asset, Tag.name) - .join(Asset, Asset.id == AssetInfo.asset_id) - .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) - .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .options(noload(AssetInfo.tags)) - .order_by(Tag.name.asc()) - ) - - rows = (session.execute(stmt)).all() - if not rows: - return None - - first_info, first_asset, _ = rows[0] - tags: list[str] = [] - seen: set[str] = set() - for _info, _asset, tag_name in rows: - if tag_name and tag_name not in seen: - seen.add(tag_name) - tags.append(tag_name) - return first_info, first_asset, tags - - -def fetch_asset_info_and_asset( - session: Session, - *, - asset_info_id: str, - owner_id: str = "", -) -> tuple[AssetInfo, Asset] | None: - stmt = ( - select(AssetInfo, Asset) - .join(Asset, Asset.id == AssetInfo.asset_id) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .limit(1) - .options(noload(AssetInfo.tags)) - ) - row = session.execute(stmt) - pair = row.first() - if not pair: - return None - return pair[0], pair[1] - -def list_cache_states_by_asset_id( - session: Session, *, asset_id: str -) -> Sequence[AssetCacheState]: - return ( - session.execute( - select(AssetCacheState) - .where(AssetCacheState.asset_id == asset_id) - .order_by(AssetCacheState.id.asc()) - ) - ).scalars().all() - - -def touch_asset_info_by_id( - session: Session, - *, - asset_info_id: str, - ts: datetime | None = None, - only_if_newer: bool = True, -) -> None: - ts = ts or utcnow() - stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) - if only_if_newer: - stmt = stmt.where( - sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) - ) - session.execute(stmt.values(last_access_time=ts)) - - -def create_asset_info_for_existing_asset( - session: Session, - *, - asset_hash: str, - name: str, - user_metadata: dict | None = None, - tags: Sequence[str] | None = None, - tag_origin: str = "manual", - owner_id: str = "", -) -> AssetInfo: - """Create or return an existing AssetInfo for an Asset identified by asset_hash.""" - now = utcnow() - asset = get_asset_by_hash(session, asset_hash=asset_hash) - if not asset: - raise ValueError(f"Unknown asset hash {asset_hash}") - - info = AssetInfo( - owner_id=owner_id, - name=name, - asset_id=asset.id, - preview_id=None, - created_at=now, - updated_at=now, - last_access_time=now, - ) - try: - with session.begin_nested(): - session.add(info) - session.flush() - except IntegrityError: - existing = ( - session.execute( - select(AssetInfo) - .options(noload(AssetInfo.tags)) - .where( - AssetInfo.asset_id == asset.id, - AssetInfo.name == name, - AssetInfo.owner_id == owner_id, - ) - .limit(1) - ) - ).unique().scalars().first() - if not existing: - raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") - return existing - - # metadata["filename"] hack - new_meta = dict(user_metadata or {}) - computed_filename = None - try: - p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) - if p: - computed_filename = compute_relative_filename(p) - except Exception: - computed_filename = None - if computed_filename: - new_meta["filename"] = computed_filename - if new_meta: - replace_asset_info_metadata_projection( - session, - asset_info_id=info.id, - user_metadata=new_meta, - ) - - if tags is not None: - set_asset_info_tags( - session, - asset_info_id=info.id, - tags=tags, - origin=tag_origin, - ) - return info - - -def set_asset_info_tags( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", -) -> dict: - desired = normalize_tags(tags) - - current = set( - tag_name for (tag_name,) in ( - session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) - ).all() - ) - - to_add = [t for t in desired if t not in current] - to_remove = [t for t in current if t not in desired] - - if to_add: - ensure_tags_exist(session, to_add, tag_type="user") - session.add_all([ - AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) - for t in to_add - ]) - session.flush() - - if to_remove: - session.execute( - delete(AssetInfoTag) - .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) - ) - session.flush() - - return {"added": to_add, "removed": to_remove, "total": desired} - - -def replace_asset_info_metadata_projection( - session: Session, - *, - asset_info_id: str, - user_metadata: dict | None = None, -) -> None: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - info.user_metadata = user_metadata or {} - info.updated_at = utcnow() - session.flush() - - session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) - session.flush() - - if not user_metadata: - return - - rows: list[AssetInfoMeta] = [] - for k, v in user_metadata.items(): - for r in project_kv(k, v): - rows.append( - AssetInfoMeta( - asset_info_id=asset_info_id, - key=r["key"], - ordinal=int(r["ordinal"]), - val_str=r.get("val_str"), - val_num=r.get("val_num"), - val_bool=r.get("val_bool"), - val_json=r.get("val_json"), - ) - ) - if rows: - session.add_all(rows) - session.flush() - - -def ingest_fs_asset( - session: Session, - *, - asset_hash: str, - abs_path: str, - size_bytes: int, - mtime_ns: int, - mime_type: str | None = None, - info_name: str | None = None, - owner_id: str = "", - preview_id: str | None = None, - user_metadata: dict | None = None, - tags: Sequence[str] = (), - tag_origin: str = "manual", - require_existing_tags: bool = False, -) -> dict: - """ - Idempotently upsert: - - Asset by content hash (create if missing) - - AssetCacheState(file_path) pointing to asset_id - - Optionally AssetInfo + tag links and metadata projection - Returns flags and ids. - """ - locator = os.path.abspath(abs_path) - now = utcnow() - - if preview_id: - if not session.get(Asset, preview_id): - preview_id = None - - out: dict[str, Any] = { - "asset_created": False, - "asset_updated": False, - "state_created": False, - "state_updated": False, - "asset_info_id": None, - } - - # 1) Asset by hash - asset = ( - session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() - if not asset: - vals = { - "hash": asset_hash, - "size_bytes": int(size_bytes), - "mime_type": mime_type, - "created_at": now, - } - res = session.execute( - sqlite.insert(Asset) - .values(**vals) - .on_conflict_do_nothing(index_elements=[Asset.hash]) - ) - if int(res.rowcount or 0) > 0: - out["asset_created"] = True - asset = ( - session.execute( - select(Asset).where(Asset.hash == asset_hash).limit(1) - ) - ).scalars().first() - if not asset: - raise RuntimeError("Asset row not found after upsert.") - else: - changed = False - if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: - asset.size_bytes = int(size_bytes) - changed = True - if mime_type and asset.mime_type != mime_type: - asset.mime_type = mime_type - changed = True - if changed: - out["asset_updated"] = True - - # 2) AssetCacheState upsert by file_path (unique) - vals = { - "asset_id": asset.id, - "file_path": locator, - "mtime_ns": int(mtime_ns), - } - ins = ( - sqlite.insert(AssetCacheState) - .values(**vals) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - - res = session.execute(ins) - if int(res.rowcount or 0) > 0: - out["state_created"] = True - else: - upd = ( - sa.update(AssetCacheState) - .where(AssetCacheState.file_path == locator) - .where( - sa.or_( - AssetCacheState.asset_id != asset.id, - AssetCacheState.mtime_ns.is_(None), - AssetCacheState.mtime_ns != int(mtime_ns), - ) - ) - .values(asset_id=asset.id, mtime_ns=int(mtime_ns)) - ) - res2 = session.execute(upd) - if int(res2.rowcount or 0) > 0: - out["state_updated"] = True - - # 3) Optional AssetInfo + tags + metadata - if info_name: - try: - with session.begin_nested(): - info = AssetInfo( - owner_id=owner_id, - name=info_name, - asset_id=asset.id, - preview_id=preview_id, - created_at=now, - updated_at=now, - last_access_time=now, - ) - session.add(info) - session.flush() - out["asset_info_id"] = info.id - except IntegrityError: - pass - - existing_info = ( - session.execute( - select(AssetInfo) - .where( - AssetInfo.asset_id == asset.id, - AssetInfo.name == info_name, - (AssetInfo.owner_id == owner_id), - ) - .limit(1) - ) - ).unique().scalar_one_or_none() - if not existing_info: - raise RuntimeError("Failed to update or insert AssetInfo.") - - if preview_id and existing_info.preview_id != preview_id: - existing_info.preview_id = preview_id - - existing_info.updated_at = now - if existing_info.last_access_time < now: - existing_info.last_access_time = now - session.flush() - out["asset_info_id"] = existing_info.id - - norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] - if norm and out["asset_info_id"] is not None: - if not require_existing_tags: - ensure_tags_exist(session, norm, tag_type="user") - - existing_tag_names = set( - name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() - ) - missing = [t for t in norm if t not in existing_tag_names] - if missing and require_existing_tags: - raise ValueError(f"Unknown tags: {missing}") - - existing_links = set( - tag_name - for (tag_name,) in ( - session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) - ) - ).all() - ) - to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] - if to_add: - session.add_all( - [ - AssetInfoTag( - asset_info_id=out["asset_info_id"], - tag_name=t, - origin=tag_origin, - added_at=now, - ) - for t in to_add - ] - ) - session.flush() - - # metadata["filename"] hack - if out["asset_info_id"] is not None: - primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) - computed_filename = compute_relative_filename(primary_path) if primary_path else None - - current_meta = existing_info.user_metadata or {} - new_meta = dict(current_meta) - if user_metadata is not None: - for k, v in user_metadata.items(): - new_meta[k] = v - if computed_filename: - new_meta["filename"] = computed_filename - - if new_meta != current_meta: - replace_asset_info_metadata_projection( - session, - asset_info_id=out["asset_info_id"], - user_metadata=new_meta, - ) - - try: - remove_missing_tag_for_asset_id(session, asset_id=asset.id) - except Exception: - logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) - return out - - -def update_asset_info_full( - session: Session, - *, - asset_info_id: str, - name: str | None = None, - tags: Sequence[str] | None = None, - user_metadata: dict | None = None, - tag_origin: str = "manual", - asset_info_row: Any = None, -) -> AssetInfo: - if not asset_info_row: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - else: - info = asset_info_row - - touched = False - if name is not None and name != info.name: - info.name = name - touched = True - - computed_filename = None - try: - p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id)) - if p: - computed_filename = compute_relative_filename(p) - except Exception: - computed_filename = None - - if user_metadata is not None: - new_meta = dict(user_metadata) - if computed_filename: - new_meta["filename"] = computed_filename - replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - else: - if computed_filename: - current_meta = info.user_metadata or {} - if current_meta.get("filename") != computed_filename: - new_meta = dict(current_meta) - new_meta["filename"] = computed_filename - replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - - if tags is not None: - set_asset_info_tags( - session, - asset_info_id=asset_info_id, - tags=tags, - origin=tag_origin, - ) - touched = True - - if touched and user_metadata is None: - info.updated_at = utcnow() - session.flush() - - return info - - -def delete_asset_info_by_id( - session: Session, - *, - asset_info_id: str, - owner_id: str, -) -> bool: - stmt = sa.delete(AssetInfo).where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - return int((session.execute(stmt)).rowcount or 0) > 0 - - -def list_tags_with_usage( - session: Session, - prefix: str | None = None, - limit: int = 100, - offset: int = 0, - include_zero: bool = True, - order: str = "count_desc", - owner_id: str = "", -) -> tuple[list[tuple[str, str, int]], int]: - counts_sq = ( - select( - AssetInfoTag.tag_name.label("tag_name"), - func.count(AssetInfoTag.asset_info_id).label("cnt"), - ) - .select_from(AssetInfoTag) - .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) - .where(visible_owner_clause(owner_id)) - .group_by(AssetInfoTag.tag_name) - .subquery() - ) - - q = ( - select( - Tag.name, - Tag.tag_type, - func.coalesce(counts_sq.c.cnt, 0).label("count"), - ) - .select_from(Tag) - .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) - ) - - if prefix: - escaped, esc = escape_like_prefix(prefix.strip().lower()) - q = q.where(Tag.name.like(escaped + "%", escape=esc)) - - if not include_zero: - q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) - - if order == "name_asc": - q = q.order_by(Tag.name.asc()) - else: - q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) - - total_q = select(func.count()).select_from(Tag) - if prefix: - escaped, esc = escape_like_prefix(prefix.strip().lower()) - total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) - if not include_zero: - total_q = total_q.where( - Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) - ) - - rows = (session.execute(q.limit(limit).offset(offset))).all() - total = (session.execute(total_q)).scalar_one() - - rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] - return rows_norm, int(total or 0) - - -def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: - wanted = normalize_tags(list(names)) - if not wanted: - return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] - ins = ( - sqlite.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - session.execute(ins) - - -def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]: - return [ - tag_name for (tag_name,) in ( - session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - ] - - -def add_tags_to_asset_info( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", - create_if_missing: bool = True, - asset_info_row: Any = None, -) -> dict: - if not asset_info_row: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"added": [], "already_present": [], "total_tags": total} - - if create_if_missing: - ensure_tags_exist(session, norm, tag_type="user") - - current = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - want = set(norm) - to_add = sorted(want - current) - - if to_add: - with session.begin_nested() as nested: - try: - session.add_all( - [ - AssetInfoTag( - asset_info_id=asset_info_id, - tag_name=t, - origin=origin, - added_at=utcnow(), - ) - for t in to_add - ] - ) - session.flush() - except IntegrityError: - nested.rollback() - - after = set(get_asset_tags(session, asset_info_id=asset_info_id)) - return { - "added": sorted(((after - current) & want)), - "already_present": sorted(want & current), - "total_tags": sorted(after), - } - - -def remove_tags_from_asset_info( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], -) -> dict: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": [], "not_present": [], "total_tags": total} - - existing = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - to_remove = sorted(set(t for t in norm if t in existing)) - not_present = sorted(set(t for t in norm if t not in existing)) - - if to_remove: - session.execute( - delete(AssetInfoTag) - .where( - AssetInfoTag.asset_info_id == asset_info_id, - AssetInfoTag.tag_name.in_(to_remove), - ) - ) - session.flush() - - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": to_remove, "not_present": not_present, "total_tags": total} - - -def remove_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, -) -> None: - session.execute( - sa.delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), - AssetInfoTag.tag_name == "missing", - ) - ) - - -def set_asset_info_preview( - session: Session, - *, - asset_info_id: str, - preview_asset_id: str | None = None, -) -> None: - """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - if preview_asset_id is None: - info.preview_id = None - else: - # validate preview asset exists - if not session.get(Asset, preview_asset_id): - raise ValueError(f"Preview Asset {preview_asset_id} not found") - info.preview_id = preview_asset_id - - info.updated_at = utcnow() - session.flush() diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py new file mode 100644 index 000000000..7888d0645 --- /dev/null +++ b/app/assets/database/queries/__init__.py @@ -0,0 +1,121 @@ +from app.assets.database.queries.asset import ( + asset_exists_by_hash, + bulk_insert_assets, + get_asset_by_hash, + get_existing_asset_ids, + reassign_asset_references, + update_asset_hash_and_mime, + upsert_asset, +) +from app.assets.database.queries.asset_reference import ( + CacheStateRow, + UnenrichedReferenceRow, + bulk_insert_references_ignore_conflicts, + bulk_update_enrichment_level, + bulk_update_is_missing, + bulk_update_needs_verify, + convert_metadata_to_rows, + delete_assets_by_ids, + delete_orphaned_seed_asset, + delete_reference_by_id, + delete_references_by_ids, + fetch_reference_and_asset, + fetch_reference_asset_and_tags, + get_or_create_reference, + get_reference_by_file_path, + get_reference_by_id, + get_reference_with_owner_check, + get_reference_ids_by_ids, + get_references_by_paths_and_asset_ids, + get_references_for_prefixes, + get_unenriched_references, + get_unreferenced_unhashed_asset_ids, + insert_reference, + list_references_by_asset_id, + list_references_page, + mark_references_missing_outside_prefixes, + reference_exists_for_asset_id, + restore_references_by_paths, + set_reference_metadata, + set_reference_preview, + soft_delete_reference_by_id, + update_reference_access_time, + update_reference_name, + update_reference_timestamps, + update_reference_updated_at, + upsert_reference, +) +from app.assets.database.queries.tags import ( + AddTagsResult, + RemoveTagsResult, + SetTagsResult, + add_missing_tag_for_asset_id, + add_tags_to_reference, + bulk_insert_tags_and_meta, + ensure_tags_exist, + get_reference_tags, + list_tags_with_usage, + remove_missing_tag_for_asset_id, + remove_tags_from_reference, + set_reference_tags, + validate_tags_exist, +) + +__all__ = [ + "AddTagsResult", + "CacheStateRow", + "RemoveTagsResult", + "SetTagsResult", + "UnenrichedReferenceRow", + "add_missing_tag_for_asset_id", + "add_tags_to_reference", + "asset_exists_by_hash", + "bulk_insert_assets", + "bulk_insert_references_ignore_conflicts", + "bulk_insert_tags_and_meta", + "bulk_update_enrichment_level", + "bulk_update_is_missing", + "bulk_update_needs_verify", + "convert_metadata_to_rows", + "delete_assets_by_ids", + "delete_orphaned_seed_asset", + "delete_reference_by_id", + "delete_references_by_ids", + "ensure_tags_exist", + "fetch_reference_and_asset", + "fetch_reference_asset_and_tags", + "get_asset_by_hash", + "get_existing_asset_ids", + "get_or_create_reference", + "get_reference_by_file_path", + "get_reference_by_id", + "get_reference_with_owner_check", + "get_reference_ids_by_ids", + "get_reference_tags", + "get_references_by_paths_and_asset_ids", + "get_references_for_prefixes", + "get_unenriched_references", + "get_unreferenced_unhashed_asset_ids", + "insert_reference", + "list_references_by_asset_id", + "list_references_page", + "list_tags_with_usage", + "mark_references_missing_outside_prefixes", + "reassign_asset_references", + "reference_exists_for_asset_id", + "remove_missing_tag_for_asset_id", + "remove_tags_from_reference", + "restore_references_by_paths", + "set_reference_metadata", + "set_reference_preview", + "soft_delete_reference_by_id", + "set_reference_tags", + "update_asset_hash_and_mime", + "update_reference_access_time", + "update_reference_name", + "update_reference_timestamps", + "update_reference_updated_at", + "upsert_asset", + "upsert_reference", + "validate_tags_exist", +] diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py new file mode 100644 index 000000000..a21f5b68f --- /dev/null +++ b/app/assets/database/queries/asset.py @@ -0,0 +1,140 @@ +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.dialects import sqlite +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks + + +def asset_exists_by_hash( + session: Session, + asset_hash: str, +) -> bool: + """ + Check if an asset with a given hash exists in database. + """ + row = ( + session.execute( + select(sa.literal(True)) + .select_from(Asset) + .where(Asset.hash == asset_hash) + .limit(1) + ) + ).first() + return row is not None + + +def get_asset_by_hash( + session: Session, + asset_hash: str, +) -> Asset | None: + return ( + (session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))) + .scalars() + .first() + ) + + +def upsert_asset( + session: Session, + asset_hash: str, + size_bytes: int, + mime_type: str | None = None, +) -> tuple[Asset, bool, bool]: + """Upsert an Asset by hash. Returns (asset, created, updated).""" + vals = {"hash": asset_hash, "size_bytes": int(size_bytes)} + if mime_type: + vals["mime_type"] = mime_type + + ins = ( + sqlite.insert(Asset) + .values(**vals) + .on_conflict_do_nothing(index_elements=[Asset.hash]) + ) + res = session.execute(ins) + created = int(res.rowcount or 0) > 0 + + asset = ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + .scalars() + .first() + ) + if not asset: + raise RuntimeError("Asset row not found after upsert.") + + updated = False + if not created: + changed = False + if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: + asset.size_bytes = int(size_bytes) + changed = True + if mime_type and asset.mime_type != mime_type: + asset.mime_type = mime_type + changed = True + if changed: + updated = True + + return asset, created, updated + + +def bulk_insert_assets( + session: Session, + rows: list[dict], +) -> None: + """Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash.""" + if not rows: + return + ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash]) + for chunk in iter_chunks(rows, calculate_rows_per_statement(5)): + session.execute(ins, chunk) + + +def get_existing_asset_ids( + session: Session, + asset_ids: list[str], +) -> set[str]: + """Return the subset of asset_ids that exist in the database.""" + if not asset_ids: + return set() + found: set[str] = set() + for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): + rows = session.execute( + select(Asset.id).where(Asset.id.in_(chunk)) + ).fetchall() + found.update(row[0] for row in rows) + return found + + +def update_asset_hash_and_mime( + session: Session, + asset_id: str, + asset_hash: str | None = None, + mime_type: str | None = None, +) -> bool: + """Update asset hash and/or mime_type. Returns True if asset was found.""" + asset = session.get(Asset, asset_id) + if not asset: + return False + if asset_hash is not None: + asset.hash = asset_hash + if mime_type is not None: + asset.mime_type = mime_type + return True + + +def reassign_asset_references( + session: Session, + from_asset_id: str, + to_asset_id: str, + reference_id: str, +) -> None: + """Reassign a reference from one asset to another. + + Used when merging a stub asset into an existing asset with the same hash. + """ + ref = session.get(AssetReference, reference_id) + if ref and ref.asset_id == from_asset_id: + ref.asset_id = to_asset_id + + session.flush() diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py new file mode 100644 index 000000000..6524791cc --- /dev/null +++ b/app/assets/database/queries/asset_reference.py @@ -0,0 +1,1033 @@ +"""Query functions for the unified AssetReference table. + +This module replaces the separate asset_info.py and cache_state.py query modules, +providing a unified interface for the merged asset_references table. +""" + +from collections import defaultdict +from datetime import datetime +from decimal import Decimal +from typing import NamedTuple, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, exists, select +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session, noload + +from app.assets.database.models import ( + Asset, + AssetReference, + AssetReferenceMeta, + AssetReferenceTag, + Tag, +) +from app.assets.database.queries.common import ( + MAX_BIND_PARAMS, + build_prefix_like_conditions, + build_visible_owner_clause, + calculate_rows_per_statement, + iter_chunks, +) +from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags + + +def _check_is_scalar(v): + if v is None: + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + + +def _scalar_to_row(key: str, ordinal: int, value) -> dict: + """Convert a scalar value to a typed projection row.""" + if value is None: + return { + "key": key, + "ordinal": ordinal, + "val_str": None, + "val_num": None, + "val_bool": None, + "val_json": None, + } + if isinstance(value, bool): + return {"key": key, "ordinal": ordinal, "val_bool": bool(value)} + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return {"key": key, "ordinal": ordinal, "val_num": num} + if isinstance(value, str): + return {"key": key, "ordinal": ordinal, "val_str": value} + return {"key": key, "ordinal": ordinal, "val_json": value} + + +def convert_metadata_to_rows(key: str, value) -> list[dict]: + """Turn a metadata key/value into typed projection rows.""" + if value is None: + return [_scalar_to_row(key, 0, None)] + + if _check_is_scalar(value): + return [_scalar_to_row(key, 0, value)] + + if isinstance(value, list): + if all(_check_is_scalar(x) for x in value): + return [_scalar_to_row(key, i, x) for i, x in enumerate(value)] + return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)] + + return [{"key": key, "ordinal": 0, "val_json": value}] + + +def _apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def _apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None = None, +) -> sa.sql.Select: + """Apply filters using asset_reference_meta projection table.""" + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + no_row_for_key = sa.not_( + sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + ) + ) + null_row = _exists_for_pred( + key, + AssetReferenceMeta.val_json.is_(None), + AssetReferenceMeta.val_str.is_(None), + AssetReferenceMeta.val_num.is_(None), + AssetReferenceMeta.val_bool.is_(None), + ) + return sa.or_(no_row_for_key, null_row) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value)) + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetReferenceMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetReferenceMeta.val_str == value) + return _exists_for_pred(key, AssetReferenceMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt + + +def get_reference_by_id( + session: Session, + reference_id: str, +) -> AssetReference | None: + return session.get(AssetReference, reference_id) + + +def get_reference_with_owner_check( + session: Session, + reference_id: str, + owner_id: str, +) -> AssetReference: + """Fetch a reference and verify ownership. + + Raises: + ValueError: if reference not found or soft-deleted + PermissionError: if owner_id doesn't match + """ + ref = get_reference_by_id(session, reference_id=reference_id) + if not ref or ref.deleted_at is not None: + raise ValueError(f"AssetReference {reference_id} not found") + if ref.owner_id and ref.owner_id != owner_id: + raise PermissionError("not owner") + return ref + + +def get_reference_by_file_path( + session: Session, + file_path: str, +) -> AssetReference | None: + """Get a reference by its file path.""" + return ( + session.execute( + select(AssetReference).where(AssetReference.file_path == file_path).limit(1) + ) + .scalars() + .first() + ) + + +def reference_exists_for_asset_id( + session: Session, + asset_id: str, +) -> bool: + q = ( + select(sa.literal(True)) + .select_from(AssetReference) + .where(AssetReference.asset_id == asset_id) + .where(AssetReference.deleted_at.is_(None)) + .limit(1) + ) + return session.execute(q).first() is not None + + +def insert_reference( + session: Session, + asset_id: str, + name: str, + owner_id: str = "", + file_path: str | None = None, + mtime_ns: int | None = None, + preview_id: str | None = None, +) -> AssetReference | None: + """Insert a new AssetReference. Returns None if unique constraint violated.""" + now = get_utc_now() + try: + with session.begin_nested(): + ref = AssetReference( + asset_id=asset_id, + name=name, + owner_id=owner_id, + file_path=file_path, + mtime_ns=mtime_ns, + preview_id=preview_id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + except IntegrityError: + return None + + +def get_or_create_reference( + session: Session, + asset_id: str, + name: str, + owner_id: str = "", + file_path: str | None = None, + mtime_ns: int | None = None, + preview_id: str | None = None, +) -> tuple[AssetReference, bool]: + """Get existing or create new AssetReference. + + For filesystem references (file_path is set), uniqueness is by file_path. + For API references (file_path is None), we look for matching + asset_id + owner_id + name. + + Returns (reference, created). + """ + ref = insert_reference( + session, + asset_id=asset_id, + name=name, + owner_id=owner_id, + file_path=file_path, + mtime_ns=mtime_ns, + preview_id=preview_id, + ) + if ref: + return ref, True + + # Find existing - priority to file_path match, then name match + if file_path: + existing = get_reference_by_file_path(session, file_path) + else: + existing = ( + session.execute( + select(AssetReference) + .where( + AssetReference.asset_id == asset_id, + AssetReference.name == name, + AssetReference.owner_id == owner_id, + AssetReference.file_path.is_(None), + ) + .limit(1) + ) + .unique() + .scalar_one_or_none() + ) + if not existing: + raise RuntimeError("Failed to find AssetReference after insert conflict.") + return existing, False + + +def update_reference_timestamps( + session: Session, + reference: AssetReference, + preview_id: str | None = None, +) -> None: + """Update timestamps and optionally preview_id on existing AssetReference.""" + now = get_utc_now() + if preview_id and reference.preview_id != preview_id: + reference.preview_id = preview_id + reference.updated_at = now + + +def list_references_page( + session: Session, + owner_id: str = "", + limit: int = 100, + offset: int = 0, + name_contains: str | None = None, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + metadata_filter: dict | None = None, + sort: str | None = None, + order: str | None = None, +) -> tuple[list[AssetReference], dict[str, list[str]], int]: + """List references with pagination, filtering, and sorting. + + Returns (references, tag_map, total_count). + """ + base = ( + select(AssetReference) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .options(noload(AssetReference.tags)) + ) + + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) + + base = _apply_tag_filters(base, include_tags, exclude_tags) + base = _apply_metadata_filter(base, metadata_filter) + + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetReference.name, + "created_at": AssetReference.created_at, + "updated_at": AssetReference.updated_at, + "last_access_time": AssetReference.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetReference.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + count_stmt = ( + select(sa.func.count()) + .select_from(AssetReference) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + ) + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + count_stmt = count_stmt.where( + AssetReference.name.ilike(f"%{escaped}%", escape=esc) + ) + count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) + + total = int(session.execute(count_stmt).scalar_one() or 0) + refs = session.execute(base).unique().scalars().all() + + id_list: list[str] = [r.id for r in refs] + tag_map: dict[str, list[str]] = defaultdict(list) + if id_list: + rows = session.execute( + select(AssetReferenceTag.asset_reference_id, Tag.name) + .join(Tag, Tag.name == AssetReferenceTag.tag_name) + .where(AssetReferenceTag.asset_reference_id.in_(id_list)) + .order_by(AssetReferenceTag.added_at) + ) + for ref_id, tag_name in rows.all(): + tag_map[ref_id].append(tag_name) + + return list(refs), tag_map, total + + +def fetch_reference_asset_and_tags( + session: Session, + reference_id: str, + owner_id: str = "", +) -> tuple[AssetReference, Asset, list[str]] | None: + stmt = ( + select(AssetReference, Asset, Tag.name) + .join(Asset, Asset.id == AssetReference.asset_id) + .join( + AssetReferenceTag, + AssetReferenceTag.asset_reference_id == AssetReference.id, + isouter=True, + ) + .join(Tag, Tag.name == AssetReferenceTag.tag_name, isouter=True) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .options(noload(AssetReference.tags)) + .order_by(Tag.name.asc()) + ) + + rows = session.execute(stmt).all() + if not rows: + return None + + first_ref, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _ref, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_ref, first_asset, tags + + +def fetch_reference_and_asset( + session: Session, + reference_id: str, + owner_id: str = "", +) -> tuple[AssetReference, Asset] | None: + stmt = ( + select(AssetReference, Asset) + .join(Asset, Asset.id == AssetReference.asset_id) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .limit(1) + .options(noload(AssetReference.tags)) + ) + pair = session.execute(stmt).first() + if not pair: + return None + return pair[0], pair[1] + + +def update_reference_access_time( + session: Session, + reference_id: str, + ts: datetime | None = None, + only_if_newer: bool = True, +) -> None: + ts = ts or get_utc_now() + stmt = sa.update(AssetReference).where(AssetReference.id == reference_id) + if only_if_newer: + stmt = stmt.where( + sa.or_( + AssetReference.last_access_time.is_(None), + AssetReference.last_access_time < ts, + ) + ) + session.execute(stmt.values(last_access_time=ts)) + + +def update_reference_name( + session: Session, + reference_id: str, + name: str, +) -> None: + """Update the name of an AssetReference.""" + now = get_utc_now() + session.execute( + sa.update(AssetReference) + .where(AssetReference.id == reference_id) + .values(name=name, updated_at=now) + ) + + +def update_reference_updated_at( + session: Session, + reference_id: str, + ts: datetime | None = None, +) -> None: + """Update the updated_at timestamp of an AssetReference.""" + ts = ts or get_utc_now() + session.execute( + sa.update(AssetReference) + .where(AssetReference.id == reference_id) + .values(updated_at=ts) + ) + + +def set_reference_metadata( + session: Session, + reference_id: str, + user_metadata: dict | None = None, +) -> None: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + ref.user_metadata = user_metadata or {} + ref.updated_at = get_utc_now() + session.flush() + + session.execute( + delete(AssetReferenceMeta).where( + AssetReferenceMeta.asset_reference_id == reference_id + ) + ) + session.flush() + + if not user_metadata: + return + + rows: list[AssetReferenceMeta] = [] + for k, v in user_metadata.items(): + for r in convert_metadata_to_rows(k, v): + rows.append( + AssetReferenceMeta( + asset_reference_id=reference_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + session.flush() + + +def delete_reference_by_id( + session: Session, + reference_id: str, + owner_id: str, +) -> bool: + stmt = sa.delete(AssetReference).where( + AssetReference.id == reference_id, + build_visible_owner_clause(owner_id), + ) + return int(session.execute(stmt).rowcount or 0) > 0 + + +def soft_delete_reference_by_id( + session: Session, + reference_id: str, + owner_id: str, +) -> bool: + """Mark a reference as soft-deleted by setting deleted_at timestamp. + + Returns True if the reference was found and marked deleted. + """ + now = get_utc_now() + stmt = ( + sa.update(AssetReference) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .values(deleted_at=now) + ) + return int(session.execute(stmt).rowcount or 0) > 0 + + +def set_reference_preview( + session: Session, + reference_id: str, + preview_asset_id: str | None = None, +) -> None: + """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + if preview_asset_id is None: + ref.preview_id = None + else: + if not session.get(Asset, preview_asset_id): + raise ValueError(f"Preview Asset {preview_asset_id} not found") + ref.preview_id = preview_asset_id + + ref.updated_at = get_utc_now() + session.flush() + + +class CacheStateRow(NamedTuple): + """Row from reference query with cache state data.""" + + reference_id: str + file_path: str + mtime_ns: int | None + needs_verify: bool + asset_id: str + asset_hash: str | None + size_bytes: int | None + + +def list_references_by_asset_id( + session: Session, + asset_id: str, +) -> Sequence[AssetReference]: + return ( + session.execute( + select(AssetReference) + .where(AssetReference.asset_id == asset_id) + .order_by(AssetReference.id.asc()) + ) + .scalars() + .all() + ) + + +def upsert_reference( + session: Session, + asset_id: str, + file_path: str, + name: str, + mtime_ns: int, + owner_id: str = "", +) -> tuple[bool, bool]: + """Upsert a reference by file_path. Returns (created, updated). + + Also restores references that were previously marked as missing. + """ + now = get_utc_now() + vals = { + "asset_id": asset_id, + "file_path": file_path, + "name": name, + "owner_id": owner_id, + "mtime_ns": int(mtime_ns), + "is_missing": False, + "created_at": now, + "updated_at": now, + "last_access_time": now, + } + ins = ( + sqlite.insert(AssetReference) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetReference.file_path]) + ) + res = session.execute(ins) + created = int(res.rowcount or 0) > 0 + + if created: + return True, False + + upd = ( + sa.update(AssetReference) + .where(AssetReference.file_path == file_path) + .where( + sa.or_( + AssetReference.asset_id != asset_id, + AssetReference.mtime_ns.is_(None), + AssetReference.mtime_ns != int(mtime_ns), + AssetReference.is_missing == True, # noqa: E712 + AssetReference.deleted_at.isnot(None), + ) + ) + .values( + asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False, + deleted_at=None, updated_at=now, + ) + ) + res2 = session.execute(upd) + updated = int(res2.rowcount or 0) > 0 + return False, updated + + +def mark_references_missing_outside_prefixes( + session: Session, + valid_prefixes: list[str], +) -> int: + """Mark references as missing when file_path doesn't match any valid prefix. + + Returns number of references marked as missing. + """ + if not valid_prefixes: + return 0 + + conds = build_prefix_like_conditions(valid_prefixes) + matches_valid_prefix = sa.or_(*conds) + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(~matches_valid_prefix) + .where(AssetReference.is_missing == False) # noqa: E712 + .values(is_missing=True) + ) + return result.rowcount + + +def restore_references_by_paths(session: Session, file_paths: list[str]) -> int: + """Restore references that were previously marked as missing. + + Returns number of references restored. + """ + if not file_paths: + return 0 + + total = 0 + for chunk in iter_chunks(file_paths, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.file_path.in_(chunk)) + .where(AssetReference.is_missing == True) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .values(is_missing=False) + ) + total += result.rowcount + return total + + +def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]: + """Get IDs of unhashed assets (hash=None) with no active references. + + An asset is considered unreferenced if it has no references, + or all its references are marked as missing. + + Returns list of asset IDs that are unreferenced. + """ + active_ref_exists = ( + sa.select(sa.literal(1)) + .where(AssetReference.asset_id == Asset.id) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .correlate(Asset) + .exists() + ) + unreferenced_subq = sa.select(Asset.id).where( + Asset.hash.is_(None), ~active_ref_exists + ) + return [row[0] for row in session.execute(unreferenced_subq).all()] + + +def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int: + """Delete assets and their references by ID. + + Returns number of assets deleted. + """ + if not asset_ids: + return 0 + total = 0 + for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): + session.execute( + sa.delete(AssetReference).where(AssetReference.asset_id.in_(chunk)) + ) + result = session.execute(sa.delete(Asset).where(Asset.id.in_(chunk))) + total += result.rowcount + return total + + +def get_references_for_prefixes( + session: Session, + prefixes: list[str], + *, + include_missing: bool = False, +) -> list[CacheStateRow]: + """Get all references with file paths matching any of the given prefixes. + + Args: + session: Database session + prefixes: List of absolute directory prefixes to match + include_missing: If False (default), exclude references marked as missing + + Returns: + List of cache state rows with joined asset data + """ + if not prefixes: + return [] + + conds = build_prefix_like_conditions(prefixes) + + query = ( + sa.select( + AssetReference.id, + AssetReference.file_path, + AssetReference.mtime_ns, + AssetReference.needs_verify, + AssetReference.asset_id, + Asset.hash, + Asset.size_bytes, + ) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(sa.or_(*conds)) + ) + + if not include_missing: + query = query.where(AssetReference.is_missing == False) # noqa: E712 + + rows = session.execute( + query.order_by(AssetReference.asset_id.asc(), AssetReference.id.asc()) + ).all() + + return [ + CacheStateRow( + reference_id=row[0], + file_path=row[1], + mtime_ns=row[2], + needs_verify=row[3], + asset_id=row[4], + asset_hash=row[5], + size_bytes=int(row[6]) if row[6] is not None else None, + ) + for row in rows + ] + + +def bulk_update_needs_verify( + session: Session, reference_ids: list[str], value: bool +) -> int: + """Set needs_verify flag for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(chunk)) + .values(needs_verify=value) + ) + total += result.rowcount + return total + + +def bulk_update_is_missing( + session: Session, reference_ids: list[str], value: bool +) -> int: + """Set is_missing flag for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(chunk)) + .values(is_missing=value) + ) + total += result.rowcount + return total + + +def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int: + """Delete references by their IDs. + + Returns: Number of rows deleted + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.delete(AssetReference).where(AssetReference.id.in_(chunk)) + ) + total += result.rowcount + return total + + +def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool: + """Delete a seed asset (hash is None) and its references. + + Returns: True if asset was deleted, False if not found or has a hash + """ + asset = session.get(Asset, asset_id) + if not asset: + return False + if asset.hash is not None: + return False + session.execute( + sa.delete(AssetReference).where(AssetReference.asset_id == asset_id) + ) + session.delete(asset) + return True + + +class UnenrichedReferenceRow(NamedTuple): + """Row for references needing enrichment.""" + + reference_id: str + asset_id: str + file_path: str + enrichment_level: int + + +def get_unenriched_references( + session: Session, + prefixes: list[str], + max_level: int = 0, + limit: int = 1000, +) -> list[UnenrichedReferenceRow]: + """Get references that need enrichment (enrichment_level <= max_level). + + Args: + session: Database session + prefixes: List of absolute directory prefixes to scan + max_level: Maximum enrichment level to include (0=stubs, 1=metadata done) + limit: Maximum number of rows to return + + Returns: + List of unenriched reference rows with file paths + """ + if not prefixes: + return [] + + conds = build_prefix_like_conditions(prefixes) + + query = ( + sa.select( + AssetReference.id, + AssetReference.asset_id, + AssetReference.file_path, + AssetReference.enrichment_level, + ) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(sa.or_(*conds)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.enrichment_level <= max_level) + .order_by(AssetReference.id.asc()) + .limit(limit) + ) + + rows = session.execute(query).all() + return [ + UnenrichedReferenceRow( + reference_id=row[0], + asset_id=row[1], + file_path=row[2], + enrichment_level=row[3], + ) + for row in rows + ] + + +def bulk_update_enrichment_level( + session: Session, + reference_ids: list[str], + level: int, +) -> int: + """Update enrichment level for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(reference_ids)) + .values(enrichment_level=level) + ) + return result.rowcount + + +def bulk_insert_references_ignore_conflicts( + session: Session, + rows: list[dict], +) -> None: + """Bulk insert reference rows with ON CONFLICT DO NOTHING on file_path. + + Each dict should have: id, asset_id, file_path, name, owner_id, mtime_ns, etc. + The is_missing field is automatically set to False for new inserts. + """ + if not rows: + return + enriched_rows = [{**row, "is_missing": False} for row in rows] + ins = sqlite.insert(AssetReference).on_conflict_do_nothing( + index_elements=[AssetReference.file_path] + ) + for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(14)): + session.execute(ins, chunk) + + +def get_references_by_paths_and_asset_ids( + session: Session, + path_to_asset: dict[str, str], +) -> set[str]: + """Query references to find paths where our asset_id won the insert. + + Args: + path_to_asset: Mapping of file_path -> asset_id we tried to insert + + Returns: + Set of file_paths where our asset_id is present + """ + if not path_to_asset: + return set() + + pairs = list(path_to_asset.items()) + winners: set[str] = set() + + # Each pair uses 2 bind params, so chunk at MAX_BIND_PARAMS // 2 + for chunk in iter_chunks(pairs, MAX_BIND_PARAMS // 2): + pairwise = sa.tuple_(AssetReference.file_path, AssetReference.asset_id).in_( + chunk + ) + result = session.execute( + select(AssetReference.file_path).where(pairwise) + ) + winners.update(result.scalars().all()) + + return winners + + +def get_reference_ids_by_ids( + session: Session, + reference_ids: list[str], +) -> set[str]: + """Query to find which reference IDs exist in the database.""" + if not reference_ids: + return set() + + found: set[str] = set() + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + select(AssetReference.id).where(AssetReference.id.in_(chunk)) + ) + found.update(result.scalars().all()) + return found diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py new file mode 100644 index 000000000..194c39a1e --- /dev/null +++ b/app/assets/database/queries/common.py @@ -0,0 +1,54 @@ +"""Shared utilities for database query modules.""" + +import os +from typing import Iterable + +import sqlalchemy as sa + +from app.assets.database.models import AssetReference +from app.assets.helpers import escape_sql_like_string + +MAX_BIND_PARAMS = 800 + + +def calculate_rows_per_statement(cols: int) -> int: + """Calculate how many rows can fit in one statement given column count.""" + return max(1, MAX_BIND_PARAMS // max(1, cols)) + + +def iter_chunks(seq, n: int): + """Yield successive n-sized chunks from seq.""" + for i in range(0, len(seq), n): + yield seq[i : i + n] + + +def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]: + """Yield chunks of rows sized to fit within bind param limits.""" + if not rows: + return + yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row)) + + +def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. + + Owner-less rows are visible to everyone. + """ + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetReference.owner_id == "" + return AssetReference.owner_id.in_(["", owner_id]) + + +def build_prefix_like_conditions( + prefixes: list[str], +) -> list[sa.sql.ColumnElement]: + """Build LIKE conditions for matching file paths under directory prefixes.""" + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_sql_like_string(base) + conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) + return conds diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py new file mode 100644 index 000000000..8b25fee67 --- /dev/null +++ b/app/assets/database/queries/tags.py @@ -0,0 +1,356 @@ +from dataclasses import dataclass +from typing import Iterable, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, func, select +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from app.assets.database.models import ( + AssetReference, + AssetReferenceMeta, + AssetReferenceTag, + Tag, +) +from app.assets.database.queries.common import ( + build_visible_owner_clause, + iter_row_chunks, +) +from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags + + +@dataclass(frozen=True) +class AddTagsResult: + added: list[str] + already_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class RemoveTagsResult: + removed: list[str] + not_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class SetTagsResult: + added: list[str] + removed: list[str] + total: list[str] + + +def validate_tags_exist(session: Session, tags: list[str]) -> None: + """Raise ValueError if any of the given tag names do not exist.""" + existing_tag_names = set( + name + for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all() + ) + missing = [t for t in tags if t not in existing_tag_names] + if missing: + raise ValueError(f"Unknown tags: {missing}") + + +def ensure_tags_exist( + session: Session, names: Iterable[str], tag_type: str = "user" +) -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + session.execute(ins) + + +def get_reference_tags(session: Session, reference_id: str) -> list[str]: + return [ + tag_name + for (tag_name,) in ( + session.execute( + select(AssetReferenceTag.tag_name).where( + AssetReferenceTag.asset_reference_id == reference_id + ) + ) + ).all() + ] + + +def set_reference_tags( + session: Session, + reference_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> SetTagsResult: + desired = normalize_tags(tags) + + current = set(get_reference_tags(session, reference_id)) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + ensure_tags_exist(session, to_add, tag_type="user") + session.add_all( + [ + AssetReferenceTag( + asset_reference_id=reference_id, + tag_name=t, + origin=origin, + added_at=get_utc_now(), + ) + for t in to_add + ] + ) + session.flush() + + if to_remove: + session.execute( + delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id == reference_id, + AssetReferenceTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + return SetTagsResult(added=to_add, removed=to_remove, total=desired) + + +def add_tags_to_reference( + session: Session, + reference_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + reference_row: AssetReference | None = None, +) -> AddTagsResult: + if not reference_row: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_reference_tags(session, reference_id=reference_id) + return AddTagsResult(added=[], already_present=[], total_tags=total) + + if create_if_missing: + ensure_tags_exist(session, norm, tag_type="user") + + current = set(get_reference_tags(session, reference_id)) + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetReferenceTag( + asset_reference_id=reference_id, + tag_name=t, + origin=origin, + added_at=get_utc_now(), + ) + for t in to_add + ] + ) + session.flush() + except IntegrityError: + nested.rollback() + + after = set(get_reference_tags(session, reference_id=reference_id)) + return AddTagsResult( + added=sorted(((after - current) & want)), + already_present=sorted(want & current), + total_tags=sorted(after), + ) + + +def remove_tags_from_reference( + session: Session, + reference_id: str, + tags: Sequence[str], +) -> RemoveTagsResult: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_reference_tags(session, reference_id=reference_id) + return RemoveTagsResult(removed=[], not_present=[], total_tags=total) + + existing = set(get_reference_tags(session, reference_id)) + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + session.execute( + delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id == reference_id, + AssetReferenceTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + total = get_reference_tags(session, reference_id=reference_id) + return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total) + + +def add_missing_tag_for_asset_id( + session: Session, + asset_id: str, + origin: str = "automatic", +) -> None: + select_rows = ( + sa.select( + AssetReference.id.label("asset_reference_id"), + sa.literal("missing").label("tag_name"), + sa.literal(origin).label("origin"), + sa.literal(get_utc_now()).label("added_at"), + ) + .where(AssetReference.asset_id == asset_id) + .where( + sa.not_( + sa.exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == "missing") + ) + ) + ) + ) + session.execute( + sqlite.insert(AssetReferenceTag) + .from_select( + ["asset_reference_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing( + index_elements=[ + AssetReferenceTag.asset_reference_id, + AssetReferenceTag.tag_name, + ] + ) + ) + + +def remove_missing_tag_for_asset_id( + session: Session, + asset_id: str, +) -> None: + session.execute( + sa.delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id.in_( + sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id) + ), + AssetReferenceTag.tag_name == "missing", + ) + ) + + +def list_tags_with_usage( + session: Session, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetReferenceTag.tag_name.label("tag_name"), + func.count(AssetReferenceTag.asset_reference_id).label("cnt"), + ) + .select_from(AssetReferenceTag) + .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.deleted_at.is_(None)) + .group_by(AssetReferenceTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + escaped, esc = escape_sql_like_string(prefix.strip().lower()) + q = q.where(Tag.name.like(escaped + "%", escape=esc)) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + escaped, esc = escape_sql_like_string(prefix.strip().lower()) + total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if not include_zero: + visible_tags_sq = ( + select(AssetReferenceTag.tag_name) + .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.deleted_at.is_(None)) + .group_by(AssetReferenceTag.tag_name) + ) + total_q = total_q.where(Tag.name.in_(visible_tags_sq)) + + rows = (session.execute(q.limit(limit).offset(offset))).all() + total = (session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) + + +def bulk_insert_tags_and_meta( + session: Session, + tag_rows: list[dict], + meta_rows: list[dict], +) -> None: + """Batch insert into asset_reference_tags and asset_reference_meta. + + Uses ON CONFLICT DO NOTHING. + + Args: + session: Database session + tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at + meta_rows: Dicts with: asset_reference_id, key, ordinal, val_* + """ + if tag_rows: + ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing( + index_elements=[ + AssetReferenceTag.asset_reference_id, + AssetReferenceTag.tag_name, + ] + ) + for chunk in iter_row_chunks(tag_rows, cols_per_row=4): + session.execute(ins_tags, chunk) + + if meta_rows: + ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing( + index_elements=[ + AssetReferenceMeta.asset_reference_id, + AssetReferenceMeta.key, + AssetReferenceMeta.ordinal, + ] + ) + for chunk in iter_row_chunks(meta_rows, cols_per_row=7): + session.execute(ins_meta, chunk) diff --git a/app/assets/database/tags.py b/app/assets/database/tags.py deleted file mode 100644 index 3ab6497c2..000000000 --- a/app/assets/database/tags.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Iterable - -import sqlalchemy -from sqlalchemy.orm import Session -from sqlalchemy.dialects import sqlite - -from app.assets.helpers import normalize_tags, utcnow -from app.assets.database.models import Tag, AssetInfoTag, AssetInfo - - -def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: - wanted = normalize_tags(list(names)) - if not wanted: - return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] - ins = ( - sqlite.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - return session.execute(ins) - -def add_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, - origin: str = "automatic", -) -> None: - select_rows = ( - sqlalchemy.select( - AssetInfo.id.label("asset_info_id"), - sqlalchemy.literal("missing").label("tag_name"), - sqlalchemy.literal(origin).label("origin"), - sqlalchemy.literal(utcnow()).label("added_at"), - ) - .where(AssetInfo.asset_id == asset_id) - .where( - sqlalchemy.not_( - sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) - ) - ) - ) - session.execute( - sqlite.insert(AssetInfoTag) - .from_select( - ["asset_info_id", "tag_name", "origin", "added_at"], - select_rows, - ) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - -def remove_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, -) -> None: - session.execute( - sqlalchemy.delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), - AssetInfoTag.tag_name == "missing", - ) - ) diff --git a/app/assets/hashing.py b/app/assets/hashing.py deleted file mode 100644 index 4b72084b9..000000000 --- a/app/assets/hashing.py +++ /dev/null @@ -1,75 +0,0 @@ -from blake3 import blake3 -from typing import IO -import os -import asyncio - - -DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB - -# NOTE: this allows hashing different representations of a file-like object -def blake3_hash( - fp: str | IO[bytes], - chunk_size: int = DEFAULT_CHUNK, -) -> str: - """ - Returns a BLAKE3 hex digest for ``fp``, which may be: - - a filename (str/bytes) or PathLike - - an open binary file object - If ``fp`` is a file object, it must be opened in **binary** mode and support - ``read``, ``seek``, and ``tell``. The function will seek to the start before - reading and will attempt to restore the original position afterward. - """ - # duck typing to check if input is a file-like object - if hasattr(fp, "read"): - return _hash_file_obj(fp, chunk_size) - - with open(os.fspath(fp), "rb") as f: - return _hash_file_obj(f, chunk_size) - - -async def blake3_hash_async( - fp: str | IO[bytes], - chunk_size: int = DEFAULT_CHUNK, -) -> str: - """Async wrapper for ``blake3_hash_sync``. - Uses a worker thread so the event loop remains responsive. - """ - # If it is a path, open inside the worker thread to keep I/O off the loop. - if hasattr(fp, "read"): - return await asyncio.to_thread(blake3_hash, fp, chunk_size) - - def _worker() -> str: - with open(os.fspath(fp), "rb") as f: - return _hash_file_obj(f, chunk_size) - - return await asyncio.to_thread(_worker) - - -def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str: - """ - Hash an already-open binary file object by streaming in chunks. - - Seeks to the beginning before reading (if supported). - - Restores the original position afterward (if tell/seek are supported). - """ - if chunk_size <= 0: - chunk_size = DEFAULT_CHUNK - - # in case file object is already open and not at the beginning, track so can be restored after hashing - orig_pos = file_obj.tell() - - try: - # seek to the beginning before reading - if orig_pos != 0: - file_obj.seek(0) - - h = blake3() - while True: - chunk = file_obj.read(chunk_size) - if not chunk: - break - h.update(chunk) - return h.hexdigest() - finally: - # restore original position in file object, if needed - if orig_pos != 0: - file_obj.seek(orig_pos) diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 5030b123a..3798f3933 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -1,226 +1,42 @@ -import contextlib import os -from decimal import Decimal -from aiohttp import web from datetime import datetime, timezone -from pathlib import Path -from typing import Literal, Any - -import folder_paths +from typing import Sequence -RootType = Literal["models", "input", "output"] -ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") - -def get_query_dict(request: web.Request) -> dict[str, Any]: +def select_best_live_path(states: Sequence) -> str: """ - Gets a dictionary of query parameters from the request. - - 'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic. + Return the best on-disk path among cache states: + 1) Prefer a path that exists with needs_verify == False (already verified). + 2) Otherwise, pick the first path that exists. + 3) Otherwise return empty string. """ - query_dict = { - key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key) - for key in request.query.keys() - } - return query_dict + alive = [ + s + for s in states + if getattr(s, "file_path", None) and os.path.isfile(s.file_path) + ] + if not alive: + return "" + for s in alive: + if not getattr(s, "needs_verify", False): + return s.file_path + return alive[0].file_path -def list_tree(base_dir: str) -> list[str]: - out: list[str] = [] - base_abs = os.path.abspath(base_dir) - if not os.path.isdir(base_abs): - return out - for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): - for name in filenames: - out.append(os.path.abspath(os.path.join(dirpath, name))) - return out -def prefixes_for_root(root: RootType) -> list[str]: - if root == "models": - bases: list[str] = [] - for _bucket, paths in get_comfy_models_folders(): - bases.extend(paths) - return [os.path.abspath(p) for p in bases] - if root == "input": - return [os.path.abspath(folder_paths.get_input_directory())] - if root == "output": - return [os.path.abspath(folder_paths.get_output_directory())] - return [] +def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]: + """Escapes %, _ and the escape char in a LIKE prefix. -def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]: - """Escapes %, _ and the escape char itself in a LIKE prefix. - Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like(). + Returns (escaped_prefix, escape_char). """ s = s.replace(escape, escape + escape) # escape the escape char first s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards return s, escape -def fast_asset_file_check( - *, - mtime_db: int | None, - size_db: int | None, - stat_result: os.stat_result, -) -> bool: - if mtime_db is None: - return False - actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)) - if int(mtime_db) != int(actual_mtime_ns): - return False - sz = int(size_db or 0) - if sz > 0: - return int(stat_result.st_size) == sz - return True -def utcnow() -> datetime: +def get_utc_now() -> datetime: """Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC.""" return datetime.now(timezone.utc).replace(tzinfo=None) -def get_comfy_models_folders() -> list[tuple[str, list[str]]]: - """Build a list of (folder_name, base_paths[]) categories that are configured for model locations. - - We trust `folder_paths.folder_names_and_paths` and include a category if - *any* of its base paths lies under the Comfy `models_dir`. - """ - targets: list[tuple[str, list[str]]] = [] - models_root = os.path.abspath(folder_paths.models_dir) - for name, values in folder_paths.folder_names_and_paths.items(): - paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI - if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): - targets.append((name, paths)) - return targets - -def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: - """Validates and maps tags -> (base_dir, subdirs_for_fs)""" - root = tags[0] - if root == "models": - if len(tags) < 2: - raise ValueError("at least two tags required for model asset") - try: - bases = folder_paths.folder_names_and_paths[tags[1]][0] - except KeyError: - raise ValueError(f"unknown model category '{tags[1]}'") - if not bases: - raise ValueError(f"no base path configured for category '{tags[1]}'") - base_dir = os.path.abspath(bases[0]) - raw_subdirs = tags[2:] - else: - base_dir = os.path.abspath( - folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory() - ) - raw_subdirs = tags[1:] - for i in raw_subdirs: - if i in (".", ".."): - raise ValueError("invalid path component in tags") - - return base_dir, raw_subdirs if raw_subdirs else [] - -def ensure_within_base(candidate: str, base: str) -> None: - cand_abs = os.path.abspath(candidate) - base_abs = os.path.abspath(base) - try: - if os.path.commonpath([cand_abs, base_abs]) != base_abs: - raise ValueError("destination escapes base directory") - except Exception: - raise ValueError("invalid destination path") - -def compute_relative_filename(file_path: str) -> str | None: - """ - Return the model's path relative to the last well-known folder (the model category), - using forward slashes, eg: - /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" - /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" - - For non-model paths, returns None. - NOTE: this is a temporary helper, used only for initializing metadata["filename"] field. - """ - try: - root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path) - except ValueError: - return None - - p = Path(rel_path) - parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] - if not parts: - return None - - if root_category == "models": - # parts[0] is the category ("checkpoints", "vae", etc) – drop it - inside = parts[1:] if len(parts) > 1 else [parts[0]] - return "/".join(inside) - return "/".join(parts) # input/output: keep all parts - -def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]: - """Given an absolute or relative file path, determine which root category the path belongs to: - - 'input' if the file resides under `folder_paths.get_input_directory()` - - 'output' if the file resides under `folder_paths.get_output_directory()` - - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()` - - Returns: - (root_category, relative_path_inside_that_root) - For 'models', the relative path is prefixed with the category name: - e.g. ('models', 'vae/test/sub/ae.safetensors') - - Raises: - ValueError: if the path does not belong to input, output, or configured model bases. - """ - fp_abs = os.path.abspath(file_path) - - def _is_within(child: str, parent: str) -> bool: - try: - return os.path.commonpath([child, parent]) == parent - except Exception: - return False - - def _rel(child: str, parent: str) -> str: - return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep) - - # 1) input - input_base = os.path.abspath(folder_paths.get_input_directory()) - if _is_within(fp_abs, input_base): - return "input", _rel(fp_abs, input_base) - - # 2) output - output_base = os.path.abspath(folder_paths.get_output_directory()) - if _is_within(fp_abs, output_base): - return "output", _rel(fp_abs, output_base) - - # 3) models (check deepest matching base to avoid ambiguity) - best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) - for bucket, bases in get_comfy_models_folders(): - for b in bases: - base_abs = os.path.abspath(b) - if not _is_within(fp_abs, base_abs): - continue - cand = (len(base_abs), bucket, _rel(fp_abs, base_abs)) - if best is None or cand[0] > best[0]: - best = cand - - if best is not None: - _, bucket, rel_inside = best - combined = os.path.join(bucket, rel_inside) - return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) - - raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}") - -def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: - """Return a tuple (name, tags) derived from a filesystem path. - - Semantics: - - Root category is determined by `get_relative_to_root_category_path_of_asset`. - - The returned `name` is the base filename with extension from the relative path. - - The returned `tags` are: - [root_category] + parent folders of the relative path (in order) - For 'models', this means: - file '/.../ModelsDir/vae/test_tag/ae.safetensors' - -> root_category='models', some_path='vae/test_tag/ae.safetensors' - -> name='ae.safetensors', tags=['models', 'vae', 'test_tag'] - - Raises: - ValueError: if the path does not belong to input, output, or configured model bases. - """ - root_category, some_path = get_relative_to_root_category_path_of_asset(file_path) - p = Path(some_path) - parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] - return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) def normalize_tags(tags: list[str] | None) -> list[str]: """ @@ -228,85 +44,22 @@ def normalize_tags(tags: list[str] | None) -> list[str]: - Stripping whitespace and converting to lowercase. - Removing duplicates. """ - return [t.strip().lower() for t in (tags or []) if (t or "").strip()] + return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip())) -def collect_models_files() -> list[str]: - out: list[str] = [] - for folder_name, bases in get_comfy_models_folders(): - rel_files = folder_paths.get_filename_list(folder_name) or [] - for rel_path in rel_files: - abs_path = folder_paths.get_full_path(folder_name, rel_path) - if not abs_path: - continue - abs_path = os.path.abspath(abs_path) - allowed = False - for b in bases: - base_abs = os.path.abspath(b) - with contextlib.suppress(Exception): - if os.path.commonpath([abs_path, base_abs]) == base_abs: - allowed = True - break - if allowed: - out.append(abs_path) - return out -def is_scalar(v): - if v is None: - return True - if isinstance(v, bool): - return True - if isinstance(v, (int, float, Decimal, str)): - return True - return False +def validate_blake3_hash(s: str) -> str: + """Validate and normalize a blake3 hash string. -def project_kv(key: str, value): + Returns canonical 'blake3:' or raises ValueError. """ - Turn a metadata key/value into typed projection rows. - Returns list[dict] with keys: - key, ordinal, and one of val_str / val_num / val_bool / val_json (others None) - """ - rows: list[dict] = [] - - def _null_row(ordinal: int) -> dict: - return { - "key": key, "ordinal": ordinal, - "val_str": None, "val_num": None, "val_bool": None, "val_json": None - } - - if value is None: - rows.append(_null_row(0)) - return rows - - if is_scalar(value): - if isinstance(value, bool): - rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) - elif isinstance(value, (int, float, Decimal)): - num = value if isinstance(value, Decimal) else Decimal(str(value)) - rows.append({"key": key, "ordinal": 0, "val_num": num}) - elif isinstance(value, str): - rows.append({"key": key, "ordinal": 0, "val_str": value}) - else: - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows - - if isinstance(value, list): - if all(is_scalar(x) for x in value): - for i, x in enumerate(value): - if x is None: - rows.append(_null_row(i)) - elif isinstance(x, bool): - rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) - elif isinstance(x, (int, float, Decimal)): - num = x if isinstance(x, Decimal) else Decimal(str(x)) - rows.append({"key": key, "ordinal": i, "val_num": num}) - elif isinstance(x, str): - rows.append({"key": key, "ordinal": i, "val_str": x}) - else: - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - for i, x in enumerate(value): - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows + s = s.strip().lower() + if not s or ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if ( + algo != "blake3" + or len(digest) != 64 + or any(c for c in digest if c not in "0123456789abcdef") + ): + raise ValueError("hash must be 'blake3:'") + return f"{algo}:{digest}" diff --git a/app/assets/manager.py b/app/assets/manager.py deleted file mode 100644 index a68c8c8ae..000000000 --- a/app/assets/manager.py +++ /dev/null @@ -1,516 +0,0 @@ -import os -import mimetypes -import contextlib -from typing import Sequence - -from app.database.db import create_session -from app.assets.api import schemas_out, schemas_in -from app.assets.database.queries import ( - asset_exists_by_hash, - asset_info_exists_for_asset_id, - get_asset_by_hash, - get_asset_info_by_id, - fetch_asset_info_asset_and_tags, - fetch_asset_info_and_asset, - create_asset_info_for_existing_asset, - touch_asset_info_by_id, - update_asset_info_full, - delete_asset_info_by_id, - list_cache_states_by_asset_id, - list_asset_infos_page, - list_tags_with_usage, - get_asset_tags, - add_tags_to_asset_info, - remove_tags_from_asset_info, - pick_best_live_path, - ingest_fs_asset, - set_asset_info_preview, -) -from app.assets.helpers import resolve_destination_from_tags, ensure_within_base -from app.assets.database.models import Asset - - -def _safe_sort_field(requested: str | None) -> str: - if not requested: - return "created_at" - v = requested.lower() - if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: - return v - return "created_at" - - -def _get_size_mtime_ns(path: str) -> tuple[int, int]: - st = os.stat(path, follow_symlinks=True) - return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) - - -def _safe_filename(name: str | None, fallback: str) -> str: - n = os.path.basename((name or "").strip() or fallback) - if n: - return n - return fallback - - -def asset_exists(*, asset_hash: str) -> bool: - """ - Check if an asset with a given hash exists in database. - """ - with create_session() as session: - return asset_exists_by_hash(session, asset_hash=asset_hash) - - -def list_assets( - *, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, - name_contains: str | None = None, - metadata_filter: dict | None = None, - limit: int = 20, - offset: int = 0, - sort: str = "created_at", - order: str = "desc", - owner_id: str = "", -) -> schemas_out.AssetsList: - sort = _safe_sort_field(sort) - order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() - - with create_session() as session: - infos, tag_map, total = list_asset_infos_page( - session, - owner_id=owner_id, - include_tags=include_tags, - exclude_tags=exclude_tags, - name_contains=name_contains, - metadata_filter=metadata_filter, - limit=limit, - offset=offset, - sort=sort, - order=order, - ) - - summaries: list[schemas_out.AssetSummary] = [] - for info in infos: - asset = info.asset - tags = tag_map.get(info.id, []) - summaries.append( - schemas_out.AssetSummary( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset else None, - mime_type=asset.mime_type if asset else None, - tags=tags, - created_at=info.created_at, - updated_at=info.updated_at, - last_access_time=info.last_access_time, - ) - ) - - return schemas_out.AssetsList( - assets=summaries, - total=total, - has_more=(offset + len(summaries)) < total, - ) - - -def get_asset( - *, - asset_info_id: str, - owner_id: str = "", -) -> schemas_out.AssetDetail: - with create_session() as session: - res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not res: - raise ValueError(f"AssetInfo {asset_info_id} not found") - info, asset, tag_names = res - preview_id = info.preview_id - - return schemas_out.AssetDetail( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, - mime_type=asset.mime_type if asset else None, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - ) - - -def resolve_asset_content_for_download( - *, - asset_info_id: str, - owner_id: str = "", -) -> tuple[str, str, str]: - with create_session() as session: - pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not pair: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - info, asset = pair - states = list_cache_states_by_asset_id(session, asset_id=asset.id) - abs_path = pick_best_live_path(states) - if not abs_path: - raise FileNotFoundError - - touch_asset_info_by_id(session, asset_info_id=asset_info_id) - session.commit() - - ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" - download_name = info.name or os.path.basename(abs_path) - return abs_path, ctype, download_name - - -def upload_asset_from_temp_path( - spec: schemas_in.UploadAssetSpec, - *, - temp_path: str, - client_filename: str | None = None, - owner_id: str = "", - expected_asset_hash: str | None = None, -) -> schemas_out.AssetCreated: - """ - Create new asset or update existing asset from a temporary file path. - """ - try: - # NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment - import app.assets.hashing as hashing - digest = hashing.blake3_hash(temp_path) - except Exception as e: - raise RuntimeError(f"failed to hash uploaded file: {e}") - asset_hash = "blake3:" + digest - - if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower(): - raise ValueError("HASH_MISMATCH") - - with create_session() as session: - existing = get_asset_by_hash(session, asset_hash=asset_hash) - if existing is not None: - with contextlib.suppress(Exception): - if temp_path and os.path.exists(temp_path): - os.remove(temp_path) - - display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) - info = create_asset_info_for_existing_asset( - session, - asset_hash=asset_hash, - name=display_name, - user_metadata=spec.user_metadata or {}, - tags=spec.tags or [], - tag_origin="manual", - owner_id=owner_id, - ) - tag_names = get_asset_tags(session, asset_info_id=info.id) - session.commit() - - return schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=existing.hash, - size=int(existing.size_bytes) if existing.size_bytes is not None else None, - mime_type=existing.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=False, - ) - - base_dir, subdirs = resolve_destination_from_tags(spec.tags) - dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir - os.makedirs(dest_dir, exist_ok=True) - - src_for_ext = (client_filename or spec.name or "").strip() - _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" - ext = _ext if 0 < len(_ext) <= 16 else "" - hashed_basename = f"{digest}{ext}" - dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) - ensure_within_base(dest_abs, base_dir) - - content_type = ( - mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] - or mimetypes.guess_type(hashed_basename, strict=False)[0] - or "application/octet-stream" - ) - - try: - os.replace(temp_path, dest_abs) - except Exception as e: - raise RuntimeError(f"failed to move uploaded file into place: {e}") - - try: - size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs) - except OSError as e: - raise RuntimeError(f"failed to stat destination file: {e}") - - with create_session() as session: - result = ingest_fs_asset( - session, - asset_hash=asset_hash, - abs_path=dest_abs, - size_bytes=size_bytes, - mtime_ns=mtime_ns, - mime_type=content_type, - info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest), - owner_id=owner_id, - preview_id=None, - user_metadata=spec.user_metadata or {}, - tags=spec.tags, - tag_origin="manual", - require_existing_tags=False, - ) - info_id = result["asset_info_id"] - if not info_id: - raise RuntimeError("failed to create asset metadata") - - pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id) - if not pair: - raise RuntimeError("inconsistent DB state after ingest") - info, asset = pair - tag_names = get_asset_tags(session, asset_info_id=info.id) - created_result = schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=asset.hash, - size=int(asset.size_bytes), - mime_type=asset.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=result["asset_created"], - ) - session.commit() - - return created_result - - -def update_asset( - *, - asset_info_id: str, - name: str | None = None, - tags: list[str] | None = None, - user_metadata: dict | None = None, - owner_id: str = "", -) -> schemas_out.AssetUpdated: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - info = update_asset_info_full( - session, - asset_info_id=asset_info_id, - name=name, - tags=tags, - user_metadata=user_metadata, - tag_origin="manual", - asset_info_row=info_row, - ) - - tag_names = get_asset_tags(session, asset_info_id=asset_info_id) - result = schemas_out.AssetUpdated( - id=info.id, - name=info.name, - asset_hash=info.asset.hash if info.asset else None, - tags=tag_names, - user_metadata=info.user_metadata or {}, - updated_at=info.updated_at, - ) - session.commit() - - return result - - -def set_asset_preview( - *, - asset_info_id: str, - preview_asset_id: str | None = None, - owner_id: str = "", -) -> schemas_out.AssetDetail: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - set_asset_info_preview( - session, - asset_info_id=asset_info_id, - preview_asset_id=preview_asset_id, - ) - - res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not res: - raise RuntimeError("State changed during preview update") - info, asset, tags = res - result = schemas_out.AssetDetail( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, - mime_type=asset.mime_type if asset else None, - tags=tags, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - ) - session.commit() - - return result - - -def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - asset_id = info_row.asset_id if info_row else None - deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not deleted: - session.commit() - return False - - if not delete_content_if_orphan or not asset_id: - session.commit() - return True - - still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id) - if still_exists: - session.commit() - return True - - states = list_cache_states_by_asset_id(session, asset_id=asset_id) - file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] - - asset_row = session.get(Asset, asset_id) - if asset_row is not None: - session.delete(asset_row) - - session.commit() - for p in file_paths: - with contextlib.suppress(Exception): - if p and os.path.isfile(p): - os.remove(p) - return True - - -def create_asset_from_hash( - *, - hash_str: str, - name: str, - tags: list[str] | None = None, - user_metadata: dict | None = None, - owner_id: str = "", -) -> schemas_out.AssetCreated | None: - canonical = hash_str.strip().lower() - with create_session() as session: - asset = get_asset_by_hash(session, asset_hash=canonical) - if not asset: - return None - - info = create_asset_info_for_existing_asset( - session, - asset_hash=canonical, - name=_safe_filename(name, fallback=canonical.split(":", 1)[1]), - user_metadata=user_metadata or {}, - tags=tags or [], - tag_origin="manual", - owner_id=owner_id, - ) - tag_names = get_asset_tags(session, asset_info_id=info.id) - result = schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=asset.hash, - size=int(asset.size_bytes), - mime_type=asset.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=False, - ) - session.commit() - - return result - - -def add_tags_to_asset( - *, - asset_info_id: str, - tags: list[str], - origin: str = "manual", - owner_id: str = "", -) -> schemas_out.TagsAdd: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - data = add_tags_to_asset_info( - session, - asset_info_id=asset_info_id, - tags=tags, - origin=origin, - create_if_missing=True, - asset_info_row=info_row, - ) - session.commit() - return schemas_out.TagsAdd(**data) - - -def remove_tags_from_asset( - *, - asset_info_id: str, - tags: list[str], - owner_id: str = "", -) -> schemas_out.TagsRemove: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - data = remove_tags_from_asset_info( - session, - asset_info_id=asset_info_id, - tags=tags, - ) - session.commit() - return schemas_out.TagsRemove(**data) - - -def list_tags( - prefix: str | None = None, - limit: int = 100, - offset: int = 0, - order: str = "count_desc", - include_zero: bool = True, - owner_id: str = "", -) -> schemas_out.TagsList: - limit = max(1, min(1000, limit)) - offset = max(0, offset) - - with create_session() as session: - rows, total = list_tags_with_usage( - session, - prefix=prefix, - limit=limit, - offset=offset, - include_zero=include_zero, - order=order, - owner_id=owner_id, - ) - - tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] - return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 0172a5c2f..e27ea5123 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -1,263 +1,567 @@ -import contextlib -import time import logging import os -import sqlalchemy +from pathlib import Path +from typing import Callable, Literal, TypedDict import folder_paths -from app.database.db import create_session, dependencies_available -from app.assets.helpers import ( - collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path, - list_tree,prefixes_for_root, escape_like_prefix, - RootType +from app.assets.database.queries import ( + add_missing_tag_for_asset_id, + bulk_update_enrichment_level, + bulk_update_is_missing, + bulk_update_needs_verify, + delete_orphaned_seed_asset, + delete_references_by_ids, + ensure_tags_exist, + get_asset_by_hash, + get_references_for_prefixes, + get_unenriched_references, + mark_references_missing_outside_prefixes, + reassign_asset_references, + remove_missing_tag_for_asset_id, + set_reference_metadata, + update_asset_hash_and_mime, ) -from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id -from app.assets.database.bulk_ops import seed_from_paths_batch -from app.assets.database.models import Asset, AssetCacheState, AssetInfo +from app.assets.services.bulk_ingest import ( + SeedAssetSpec, + batch_insert_seed_assets, +) +from app.assets.services.file_utils import ( + get_mtime_ns, + is_visible, + list_files_recursively, + verify_file_unchanged, +) +from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash +from app.assets.services.metadata_extract import extract_file_metadata +from app.assets.services.path_utils import ( + compute_relative_filename, + get_comfy_models_folders, + get_name_and_tags_from_asset_path, +) +from app.database.db import create_session -def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None: - """ - Scan the given roots and seed the assets into the database. - """ - if not dependencies_available(): - if enable_logging: - logging.warning("Database dependencies not available, skipping assets scan") - return - t_start = time.perf_counter() - created = 0 - skipped_existing = 0 - orphans_pruned = 0 - paths: list[str] = [] - try: - existing_paths: set[str] = set() - for r in roots: - try: - survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True) - if survivors: - existing_paths.update(survivors) - except Exception as e: - logging.exception("fast DB scan failed for %s: %s", r, e) +class _RefInfo(TypedDict): + ref_id: str + file_path: str + exists: bool + stat_unchanged: bool + needs_verify: bool - try: - orphans_pruned = _prune_orphaned_assets(roots) - except Exception as e: - logging.exception("orphan pruning failed: %s", e) - if "models" in roots: - paths.extend(collect_models_files()) - if "input" in roots: - paths.extend(list_tree(folder_paths.get_input_directory())) - if "output" in roots: - paths.extend(list_tree(folder_paths.get_output_directory())) +class _AssetAccumulator(TypedDict): + hash: str | None + size_db: int + refs: list[_RefInfo] - specs: list[dict] = [] - tag_pool: set[str] = set() - for p in paths: - abs_p = os.path.abspath(p) - if abs_p in existing_paths: - skipped_existing += 1 + +RootType = Literal["models", "input", "output"] + + +def get_prefixes_for_root(root: RootType) -> list[str]: + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + return [os.path.abspath(p) for p in bases] + if root == "input": + return [os.path.abspath(folder_paths.get_input_directory())] + if root == "output": + return [os.path.abspath(folder_paths.get_output_directory())] + return [] + + +def get_all_known_prefixes() -> list[str]: + """Get all known asset prefixes across all root types.""" + all_roots: tuple[RootType, ...] = ("models", "input", "output") + return [p for root in all_roots for p in get_prefixes_for_root(root)] + + +def collect_models_files() -> list[str]: + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + rel_files = folder_paths.get_filename_list(folder_name) or [] + for rel_path in rel_files: + if not all(is_visible(part) for part in Path(rel_path).parts): continue - try: - stat_p = os.stat(abs_p, follow_symlinks=False) - except OSError: + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: continue - # skip empty files - if not stat_p.st_size: - continue - name, tags = get_name_and_tags_from_asset_path(abs_p) - specs.append( - { - "abs_path": abs_p, - "size_bytes": stat_p.st_size, - "mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)), - "info_name": name, - "tags": tags, - "fname": compute_relative_filename(abs_p), - } - ) - for t in tags: - tag_pool.add(t) - # if no file specs, nothing to do - if not specs: - return - with create_session() as sess: - if tag_pool: - ensure_tags_exist(sess, tag_pool, tag_type="user") - - result = seed_from_paths_batch(sess, specs=specs, owner_id="") - created += result["inserted_infos"] - sess.commit() - finally: - if enable_logging: - logging.info( - "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)", - roots, - time.perf_counter() - t_start, - created, - skipped_existing, - orphans_pruned, - len(paths), - ) + abs_path = os.path.abspath(abs_path) + allowed = False + abs_p = Path(abs_path) + for b in bases: + if abs_p.is_relative_to(os.path.abspath(b)): + allowed = True + break + if allowed: + out.append(abs_path) + return out -def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int: - """Prune cache states outside configured prefixes, then delete orphaned seed assets.""" - all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)] - if not all_prefixes: - return 0 - - def make_prefix_condition(prefix: str): - base = prefix if prefix.endswith(os.sep) else prefix + os.sep - escaped, esc = escape_like_prefix(base) - return AssetCacheState.file_path.like(escaped + "%", escape=esc) - - matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes]) - - orphan_subq = ( - sqlalchemy.select(Asset.id) - .outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id) - .where(Asset.hash.is_(None), AssetCacheState.id.is_(None)) - ).scalar_subquery() - - with create_session() as sess: - sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix)) - sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq))) - result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq))) - sess.commit() - return result.rowcount - - -def _fast_db_consistency_pass( +def sync_references_with_filesystem( + session, root: RootType, - *, collect_existing_paths: bool = False, update_missing_tags: bool = False, ) -> set[str] | None: - """Fast DB+FS pass for a root: - - Toggle needs_verify per state using fast check - - For hashed assets with at least one fast-ok state in this root: delete stale missing states - - For seed assets with all states missing: delete Asset and its AssetInfos - - Optionally add/remove 'missing' tags based on fast-ok in this root - - Optionally return surviving absolute paths + """Reconcile asset references with filesystem for a root. + + - Toggle needs_verify per reference using mtime/size stat check + - For hashed assets with at least one stat-unchanged ref: delete stale missing refs + - For seed assets with all refs missing: delete Asset and its references + - Optionally add/remove 'missing' tags based on stat check in this root + - Optionally return surviving absolute paths + + Args: + session: Database session + root: Root type to scan + collect_existing_paths: If True, return set of surviving file paths + update_missing_tags: If True, update 'missing' tags based on file status + + Returns: + Set of surviving absolute paths if collect_existing_paths=True, else None """ - prefixes = prefixes_for_root(root) + prefixes = get_prefixes_for_root(root) if not prefixes: return set() if collect_existing_paths else None - conds = [] - for p in prefixes: - base = os.path.abspath(p) - if not base.endswith(os.sep): - base += os.sep - escaped, esc = escape_like_prefix(base) - conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + rows = get_references_for_prefixes( + session, prefixes, include_missing=update_missing_tags + ) + + by_asset: dict[str, _AssetAccumulator] = {} + for row in rows: + acc = by_asset.get(row.asset_id) + if acc is None: + acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []} + by_asset[row.asset_id] = acc + + stat_unchanged = False + try: + exists = True + stat_unchanged = verify_file_unchanged( + mtime_db=row.mtime_ns, + size_db=acc["size_db"], + stat_result=os.stat(row.file_path, follow_symlinks=True), + ) + except FileNotFoundError: + exists = False + except PermissionError: + exists = True + logging.debug("Permission denied accessing %s", row.file_path) + except OSError as e: + exists = False + logging.debug("OSError checking %s: %s", row.file_path, e) + + acc["refs"].append( + { + "ref_id": row.reference_id, + "file_path": row.file_path, + "exists": exists, + "stat_unchanged": stat_unchanged, + "needs_verify": row.needs_verify, + } + ) + + to_set_verify: list[str] = [] + to_clear_verify: list[str] = [] + stale_ref_ids: list[str] = [] + to_mark_missing: list[str] = [] + to_clear_missing: list[str] = [] + survivors: set[str] = set() + + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + refs = acc["refs"] + any_unchanged = any(r["stat_unchanged"] for r in refs) + all_missing = all(not r["exists"] for r in refs) + + for r in refs: + if not r["exists"]: + to_mark_missing.append(r["ref_id"]) + continue + if r["stat_unchanged"]: + to_clear_missing.append(r["ref_id"]) + if r["needs_verify"]: + to_clear_verify.append(r["ref_id"]) + if not r["stat_unchanged"] and not r["needs_verify"]: + to_set_verify.append(r["ref_id"]) + + if a_hash is None: + if refs and all_missing: + delete_orphaned_seed_asset(session, aid) + else: + for r in refs: + if r["exists"]: + survivors.add(os.path.abspath(r["file_path"])) + continue + + if any_unchanged: + for r in refs: + if not r["exists"]: + stale_ref_ids.append(r["ref_id"]) + if update_missing_tags: + try: + remove_missing_tag_for_asset_id(session, asset_id=aid) + except Exception as e: + logging.warning( + "Failed to remove missing tag for asset %s: %s", aid, e + ) + elif update_missing_tags: + try: + add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic") + except Exception as e: + logging.warning("Failed to add missing tag for asset %s: %s", aid, e) + + for r in refs: + if r["exists"]: + survivors.add(os.path.abspath(r["file_path"])) + + delete_references_by_ids(session, stale_ref_ids) + stale_set = set(stale_ref_ids) + to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set] + bulk_update_is_missing(session, to_mark_missing, value=True) + bulk_update_is_missing(session, to_clear_missing, value=False) + bulk_update_needs_verify(session, to_set_verify, value=True) + bulk_update_needs_verify(session, to_clear_verify, value=False) + + return survivors if collect_existing_paths else None + + +def sync_root_safely(root: RootType) -> set[str]: + """Sync a single root's references with the filesystem. + + Returns survivors (existing paths) or empty set on failure. + """ + try: + with create_session() as sess: + survivors = sync_references_with_filesystem( + sess, + root, + collect_existing_paths=True, + update_missing_tags=True, + ) + sess.commit() + return survivors or set() + except Exception as e: + logging.exception("fast DB scan failed for %s: %s", root, e) + return set() + + +def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int: + """Mark references as missing when outside the given prefixes. + + This is a non-destructive soft-delete. Returns count marked or 0 on failure. + """ + try: + with create_session() as sess: + count = mark_references_missing_outside_prefixes(sess, prefixes) + sess.commit() + return count + except Exception as e: + logging.exception("marking missing assets failed: %s", e) + return 0 + + +def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]: + """Collect all file paths for the given roots.""" + paths: list[str] = [] + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_files_recursively(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_files_recursively(folder_paths.get_output_directory())) + return paths + + +def build_asset_specs( + paths: list[str], + existing_paths: set[str], + enable_metadata_extraction: bool = True, + compute_hashes: bool = False, +) -> tuple[list[SeedAssetSpec], set[str], int]: + """Build asset specs from paths, returning (specs, tag_pool, skipped_count). + + Args: + paths: List of file paths to process + existing_paths: Set of paths that already exist in the database + enable_metadata_extraction: If True, extract tier 1 & 2 metadata + compute_hashes: If True, compute blake3 hashes (slow for large files) + """ + specs: list[SeedAssetSpec] = [] + tag_pool: set[str] = set() + skipped = 0 + + for p in paths: + abs_p = os.path.abspath(p) + if abs_p in existing_paths: + skipped += 1 + continue + try: + stat_p = os.stat(abs_p, follow_symlinks=True) + except OSError: + continue + if not stat_p.st_size: + continue + name, tags = get_name_and_tags_from_asset_path(abs_p) + rel_fname = compute_relative_filename(abs_p) + + # Extract metadata (tier 1: filesystem, tier 2: safetensors header) + metadata = None + if enable_metadata_extraction: + metadata = extract_file_metadata( + abs_p, + stat_result=stat_p, + relative_filename=rel_fname, + ) + + # Compute hash if requested + asset_hash: str | None = None + if compute_hashes: + try: + digest, _ = compute_blake3_hash(abs_p) + asset_hash = "blake3:" + digest + except Exception as e: + logging.warning("Failed to hash %s: %s", abs_p, e) + + mime_type = metadata.content_type if metadata else None + specs.append( + { + "abs_path": abs_p, + "size_bytes": stat_p.st_size, + "mtime_ns": get_mtime_ns(stat_p), + "info_name": name, + "tags": tags, + "fname": rel_fname, + "metadata": metadata, + "hash": asset_hash, + "mime_type": mime_type, + } + ) + tag_pool.update(tags) + + return specs, tag_pool, skipped + + + +def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: + """Insert asset specs into database, returning count of created refs.""" + if not specs: + return 0 + with create_session() as sess: + if tag_pool: + ensure_tags_exist(sess, tag_pool, tag_type="user") + result = batch_insert_seed_assets(sess, specs=specs, owner_id="") + sess.commit() + return result.inserted_refs + + +# Enrichment level constants +ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only +ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type) +ENRICHMENT_HASHED = 2 # Hash computed (blake3) + + +def get_unenriched_assets_for_roots( + roots: tuple[RootType, ...], + max_level: int = ENRICHMENT_STUB, + limit: int = 1000, +) -> list: + """Get assets that need enrichment for the given roots. + + Args: + roots: Tuple of root types to scan + max_level: Maximum enrichment level to include + limit: Maximum number of rows to return + + Returns: + List of UnenrichedReferenceRow + """ + prefixes: list[str] = [] + for root in roots: + prefixes.extend(get_prefixes_for_root(root)) + + if not prefixes: + return [] with create_session() as sess: - rows = ( - sess.execute( - sqlalchemy.select( - AssetCacheState.id, - AssetCacheState.file_path, - AssetCacheState.mtime_ns, - AssetCacheState.needs_verify, - AssetCacheState.asset_id, - Asset.hash, - Asset.size_bytes, - ) - .join(Asset, Asset.id == AssetCacheState.asset_id) - .where(sqlalchemy.or_(*conds)) - .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + return get_unenriched_references( + sess, prefixes, max_level=max_level, limit=limit + ) + + +def enrich_asset( + session, + file_path: str, + reference_id: str, + asset_id: str, + extract_metadata: bool = True, + compute_hash: bool = False, + interrupt_check: Callable[[], bool] | None = None, + hash_checkpoints: dict[str, HashCheckpoint] | None = None, +) -> int: + """Enrich a single asset with metadata and/or hash. + + Args: + session: Database session (caller manages lifecycle) + file_path: Absolute path to the file + reference_id: ID of the reference to update + asset_id: ID of the asset to update (for mime_type and hash) + extract_metadata: If True, extract safetensors header and mime type + compute_hash: If True, compute blake3 hash + interrupt_check: Optional non-blocking callable that returns True if + the operation should be interrupted (e.g. paused or cancelled) + hash_checkpoints: Optional dict for saving/restoring hash progress + across interruptions, keyed by file path + + Returns: + New enrichment level achieved + """ + new_level = ENRICHMENT_STUB + + try: + stat_p = os.stat(file_path, follow_symlinks=True) + except OSError: + return new_level + + rel_fname = compute_relative_filename(file_path) + mime_type: str | None = None + metadata = None + + if extract_metadata: + metadata = extract_file_metadata( + file_path, + stat_result=stat_p, + relative_filename=rel_fname, + ) + if metadata: + mime_type = metadata.content_type + new_level = ENRICHMENT_METADATA + + full_hash: str | None = None + if compute_hash: + try: + mtime_before = get_mtime_ns(stat_p) + size_before = stat_p.st_size + + # Restore checkpoint if available and file unchanged + checkpoint = None + if hash_checkpoints is not None: + checkpoint = hash_checkpoints.get(file_path) + if checkpoint is not None: + cur_stat = os.stat(file_path, follow_symlinks=True) + if (checkpoint.mtime_ns != get_mtime_ns(cur_stat) + or checkpoint.file_size != cur_stat.st_size): + checkpoint = None + hash_checkpoints.pop(file_path, None) + else: + mtime_before = get_mtime_ns(cur_stat) + + digest, new_checkpoint = compute_blake3_hash( + file_path, + interrupt_check=interrupt_check, + checkpoint=checkpoint, ) - ).all() - by_asset: dict[str, dict] = {} - for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: - acc = by_asset.get(aid) - if acc is None: - acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} - by_asset[aid] = acc + if digest is None: + # Interrupted — save checkpoint for later resumption + if hash_checkpoints is not None and new_checkpoint is not None: + new_checkpoint.mtime_ns = mtime_before + new_checkpoint.file_size = size_before + hash_checkpoints[file_path] = new_checkpoint + return new_level + + # Completed — clear any saved checkpoint + if hash_checkpoints is not None: + hash_checkpoints.pop(file_path, None) + + stat_after = os.stat(file_path, follow_symlinks=True) + mtime_after = get_mtime_ns(stat_after) + if mtime_before != mtime_after: + logging.warning("File modified during hashing, discarding hash: %s", file_path) + else: + full_hash = f"blake3:{digest}" + metadata_ok = not extract_metadata or metadata is not None + if metadata_ok: + new_level = ENRICHMENT_HASHED + except Exception as e: + logging.warning("Failed to hash %s: %s", file_path, e) + + if extract_metadata and metadata: + user_metadata = metadata.to_user_metadata() + set_reference_metadata(session, reference_id, user_metadata) + + if full_hash: + existing = get_asset_by_hash(session, full_hash) + if existing and existing.id != asset_id: + reassign_asset_references(session, asset_id, existing.id, reference_id) + delete_orphaned_seed_asset(session, asset_id) + if mime_type: + update_asset_hash_and_mime(session, existing.id, mime_type=mime_type) + else: + update_asset_hash_and_mime(session, asset_id, full_hash, mime_type) + elif mime_type: + update_asset_hash_and_mime(session, asset_id, mime_type=mime_type) + + bulk_update_enrichment_level(session, [reference_id], new_level) + session.commit() + + return new_level + + +def enrich_assets_batch( + rows: list, + extract_metadata: bool = True, + compute_hash: bool = False, + interrupt_check: Callable[[], bool] | None = None, + hash_checkpoints: dict[str, HashCheckpoint] | None = None, +) -> tuple[int, list[str]]: + """Enrich a batch of assets. + + Uses a single DB session for the entire batch, committing after each + individual asset to avoid long-held transactions while eliminating + per-asset session creation overhead. + + Args: + rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots + extract_metadata: If True, extract metadata for each asset + compute_hash: If True, compute hash for each asset + interrupt_check: Optional non-blocking callable that returns True if + the operation should be interrupted (e.g. paused or cancelled) + hash_checkpoints: Optional dict for saving/restoring hash progress + across interruptions, keyed by file path + + Returns: + Tuple of (enriched_count, failed_reference_ids) + """ + enriched = 0 + failed_ids: list[str] = [] + + with create_session() as sess: + for row in rows: + if interrupt_check is not None and interrupt_check(): + break - fast_ok = False try: - exists = True - fast_ok = fast_asset_file_check( - mtime_db=mtime_db, - size_db=acc["size_db"], - stat_result=os.stat(fp, follow_symlinks=True), + new_level = enrich_asset( + sess, + file_path=row.file_path, + reference_id=row.reference_id, + asset_id=row.asset_id, + extract_metadata=extract_metadata, + compute_hash=compute_hash, + interrupt_check=interrupt_check, + hash_checkpoints=hash_checkpoints, ) - except FileNotFoundError: - exists = False - except OSError: - exists = False - - acc["states"].append({ - "sid": sid, - "fp": fp, - "exists": exists, - "fast_ok": fast_ok, - "needs_verify": bool(needs_verify), - }) - - to_set_verify: list[int] = [] - to_clear_verify: list[int] = [] - stale_state_ids: list[int] = [] - survivors: set[str] = set() - - for aid, acc in by_asset.items(): - a_hash = acc["hash"] - states = acc["states"] - any_fast_ok = any(s["fast_ok"] for s in states) - all_missing = all(not s["exists"] for s in states) - - for s in states: - if not s["exists"]: - continue - if s["fast_ok"] and s["needs_verify"]: - to_clear_verify.append(s["sid"]) - if not s["fast_ok"] and not s["needs_verify"]: - to_set_verify.append(s["sid"]) - - if a_hash is None: - if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists - sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid)) - asset = sess.get(Asset, aid) - if asset: - sess.delete(asset) + if new_level > row.enrichment_level: + enriched += 1 else: - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["fp"])) - continue + failed_ids.append(row.reference_id) + except Exception as e: + logging.warning("Failed to enrich %s: %s", row.file_path, e) + sess.rollback() + failed_ids.append(row.reference_id) - if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records - for s in states: - if not s["exists"]: - stale_state_ids.append(s["sid"]) - if update_missing_tags: - with contextlib.suppress(Exception): - remove_missing_tag_for_asset_id(sess, asset_id=aid) - elif update_missing_tags: - with contextlib.suppress(Exception): - add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") - - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["fp"])) - - if stale_state_ids: - sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids))) - if to_set_verify: - sess.execute( - sqlalchemy.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_set_verify)) - .values(needs_verify=True) - ) - if to_clear_verify: - sess.execute( - sqlalchemy.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_clear_verify)) - .values(needs_verify=False) - ) - sess.commit() - return survivors if collect_existing_paths else None + return enriched, failed_ids diff --git a/app/assets/seeder.py b/app/assets/seeder.py new file mode 100644 index 000000000..029448464 --- /dev/null +++ b/app/assets/seeder.py @@ -0,0 +1,794 @@ +"""Background asset seeder with thread management and cancellation support.""" + +import logging +import os +import threading +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Callable + +from app.assets.scanner import ( + ENRICHMENT_METADATA, + ENRICHMENT_STUB, + RootType, + build_asset_specs, + collect_paths_for_roots, + enrich_assets_batch, + get_all_known_prefixes, + get_prefixes_for_root, + get_unenriched_assets_for_roots, + insert_asset_specs, + mark_missing_outside_prefixes_safely, + sync_root_safely, +) +from app.database.db import dependencies_available + + +class ScanInProgressError(Exception): + """Raised when an operation cannot proceed because a scan is running.""" + + +class State(Enum): + """Seeder state machine states.""" + + IDLE = "IDLE" + RUNNING = "RUNNING" + PAUSED = "PAUSED" + CANCELLING = "CANCELLING" + + +class ScanPhase(Enum): + """Scan phase options.""" + + FAST = "fast" # Phase 1: filesystem only (stubs) + ENRICH = "enrich" # Phase 2: metadata + hash + FULL = "full" # Both phases sequentially + + +@dataclass +class Progress: + """Progress information for a scan operation.""" + + scanned: int = 0 + total: int = 0 + created: int = 0 + skipped: int = 0 + + +@dataclass +class ScanStatus: + """Current status of the asset seeder.""" + + state: State + progress: Progress | None + errors: list[str] = field(default_factory=list) + + +ProgressCallback = Callable[[Progress], None] + + +class _AssetSeeder: + """Background asset scanning manager. + + Spawns ephemeral daemon threads for scanning. + Each scan creates a new thread that exits when complete. + Use the module-level ``asset_seeder`` instance. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = State.IDLE + self._progress: Progress | None = None + self._last_progress: Progress | None = None + self._errors: list[str] = [] + self._thread: threading.Thread | None = None + self._cancel_event = threading.Event() + self._run_gate = threading.Event() + self._run_gate.set() # Start unpaused (set = running, clear = paused) + self._roots: tuple[RootType, ...] = () + self._phase: ScanPhase = ScanPhase.FULL + self._compute_hashes: bool = False + self._prune_first: bool = False + self._progress_callback: ProgressCallback | None = None + self._disabled: bool = False + + def disable(self) -> None: + """Disable the asset seeder, preventing any scans from starting.""" + self._disabled = True + logging.info("Asset seeder disabled") + + def is_disabled(self) -> bool: + """Check if the asset seeder is disabled.""" + return self._disabled + + def start( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + phase: ScanPhase = ScanPhase.FULL, + progress_callback: ProgressCallback | None = None, + prune_first: bool = False, + compute_hashes: bool = False, + ) -> bool: + """Start a background scan for the given roots. + + Args: + roots: Tuple of root types to scan (models, input, output) + phase: Scan phase to run (FAST, ENRICH, or FULL for both) + progress_callback: Optional callback called with progress updates + prune_first: If True, prune orphaned assets before scanning + compute_hashes: If True, compute blake3 hashes (slow) + + Returns: + True if scan was started, False if already running + """ + if self._disabled: + logging.debug("Asset seeder is disabled, skipping start") + return False + logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value) + with self._lock: + if self._state != State.IDLE: + logging.info("Asset seeder already running, skipping start") + return False + self._state = State.RUNNING + self._progress = Progress() + self._errors = [] + self._roots = roots + self._phase = phase + self._prune_first = prune_first + self._compute_hashes = compute_hashes + self._progress_callback = progress_callback + self._cancel_event.clear() + self._run_gate.set() # Ensure unpaused when starting + self._thread = threading.Thread( + target=self._run_scan, + name="_AssetSeeder", + daemon=True, + ) + self._thread.start() + return True + + def start_fast( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + progress_callback: ProgressCallback | None = None, + prune_first: bool = False, + ) -> bool: + """Start a fast scan (phase 1 only) - creates stub records. + + Args: + roots: Tuple of root types to scan + progress_callback: Optional callback for progress updates + prune_first: If True, prune orphaned assets before scanning + + Returns: + True if scan was started, False if already running + """ + return self.start( + roots=roots, + phase=ScanPhase.FAST, + progress_callback=progress_callback, + prune_first=prune_first, + compute_hashes=False, + ) + + def start_enrich( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + progress_callback: ProgressCallback | None = None, + compute_hashes: bool = False, + ) -> bool: + """Start an enrichment scan (phase 2 only) - extracts metadata and hashes. + + Args: + roots: Tuple of root types to scan + progress_callback: Optional callback for progress updates + compute_hashes: If True, compute blake3 hashes + + Returns: + True if scan was started, False if already running + """ + return self.start( + roots=roots, + phase=ScanPhase.ENRICH, + progress_callback=progress_callback, + prune_first=False, + compute_hashes=compute_hashes, + ) + + def cancel(self) -> bool: + """Request cancellation of the current scan. + + Returns: + True if cancellation was requested, False if not running or paused + """ + with self._lock: + if self._state not in (State.RUNNING, State.PAUSED): + return False + logging.info("Asset seeder cancelling (was %s)", self._state.value) + self._state = State.CANCELLING + self._cancel_event.set() + self._run_gate.set() # Unblock if paused so thread can exit + return True + + def stop(self) -> bool: + """Stop the current scan (alias for cancel). + + Returns: + True if stop was requested, False if not running + """ + return self.cancel() + + def pause(self) -> bool: + """Pause the current scan. + + The scan will complete its current batch before pausing. + + Returns: + True if pause was requested, False if not running + """ + with self._lock: + if self._state != State.RUNNING: + return False + logging.info("Asset seeder pausing") + self._state = State.PAUSED + self._run_gate.clear() + return True + + def resume(self) -> bool: + """Resume a paused scan. + + This is a noop if the scan is not in the PAUSED state + + Returns: + True if resumed, False if not paused + """ + with self._lock: + if self._state != State.PAUSED: + return False + logging.info("Asset seeder resuming") + self._state = State.RUNNING + self._run_gate.set() + self._emit_event("assets.seed.resumed", {}) + return True + + def restart( + self, + roots: tuple[RootType, ...] | None = None, + phase: ScanPhase | None = None, + progress_callback: ProgressCallback | None = None, + prune_first: bool | None = None, + compute_hashes: bool | None = None, + timeout: float = 5.0, + ) -> bool: + """Cancel any running scan and start a new one. + + Args: + roots: Roots to scan (defaults to previous roots) + phase: Scan phase (defaults to previous phase) + progress_callback: Progress callback (defaults to previous) + prune_first: Prune before scan (defaults to previous) + compute_hashes: Compute hashes (defaults to previous) + timeout: Max seconds to wait for current scan to stop + + Returns: + True if new scan was started, False if failed to stop previous + """ + logging.info("Asset seeder restart requested") + with self._lock: + prev_roots = self._roots + prev_phase = self._phase + prev_callback = self._progress_callback + prev_prune = self._prune_first + prev_hashes = self._compute_hashes + + self.cancel() + if not self.wait(timeout=timeout): + return False + + cb = progress_callback if progress_callback is not None else prev_callback + return self.start( + roots=roots if roots is not None else prev_roots, + phase=phase if phase is not None else prev_phase, + progress_callback=cb, + prune_first=prune_first if prune_first is not None else prev_prune, + compute_hashes=( + compute_hashes if compute_hashes is not None else prev_hashes + ), + ) + + def wait(self, timeout: float | None = None) -> bool: + """Wait for the current scan to complete. + + Args: + timeout: Maximum seconds to wait, or None for no timeout + + Returns: + True if scan completed, False if timeout expired or no scan running + """ + with self._lock: + thread = self._thread + if thread is None: + return True + thread.join(timeout=timeout) + return not thread.is_alive() + + def get_status(self) -> ScanStatus: + """Get the current status and progress of the seeder.""" + with self._lock: + src = self._progress or self._last_progress + return ScanStatus( + state=self._state, + progress=Progress( + scanned=src.scanned, + total=src.total, + created=src.created, + skipped=src.skipped, + ) + if src + else None, + errors=list(self._errors), + ) + + def shutdown(self, timeout: float = 5.0) -> None: + """Gracefully shutdown: cancel any running scan and wait for thread. + + Args: + timeout: Maximum seconds to wait for thread to exit + """ + self.cancel() + self.wait(timeout=timeout) + with self._lock: + self._thread = None + + def mark_missing_outside_prefixes(self) -> int: + """Mark references as missing when outside all known root prefixes. + + This is a non-destructive soft-delete operation. Assets and their + metadata are preserved, but references are flagged as missing. + They can be restored if the file reappears in a future scan. + + This operation is decoupled from scanning to prevent partial scans + from accidentally marking assets belonging to other roots. + + Should be called explicitly when cleanup is desired, typically after + a full scan of all roots or during maintenance. + + Returns: + Number of references marked as missing + + Raises: + ScanInProgressError: If a scan is currently running + """ + with self._lock: + if self._state != State.IDLE: + raise ScanInProgressError( + "Cannot mark missing assets while scan is running" + ) + self._state = State.RUNNING + + try: + if not dependencies_available(): + logging.warning( + "Database dependencies not available, skipping mark missing" + ) + return 0 + + all_prefixes = get_all_known_prefixes() + marked = mark_missing_outside_prefixes_safely(all_prefixes) + if marked > 0: + logging.info("Marked %d references as missing", marked) + return marked + finally: + with self._lock: + self._last_progress = self._progress + self._state = State.IDLE + self._progress = None + + def _is_cancelled(self) -> bool: + """Check if cancellation has been requested.""" + return self._cancel_event.is_set() + + def _is_paused_or_cancelled(self) -> bool: + """Non-blocking check: True if paused or cancelled. + + Use as interrupt_check for I/O-bound work (e.g. hashing) so that + file handles are released immediately on pause rather than held + open while blocked. The caller is responsible for blocking on + _check_pause_and_cancel() afterward. + """ + return not self._run_gate.is_set() or self._cancel_event.is_set() + + def _check_pause_and_cancel(self) -> bool: + """Block while paused, then check if cancelled. + + Call this at checkpoint locations in scan loops. It will: + 1. Block indefinitely while paused (until resume or cancel) + 2. Return True if cancelled, False to continue + + Returns: + True if scan should stop, False to continue + """ + if not self._run_gate.is_set(): + self._emit_event("assets.seed.paused", {}) + self._run_gate.wait() # Blocks if paused + return self._is_cancelled() + + def _emit_event(self, event_type: str, data: dict) -> None: + """Emit a WebSocket event if server is available.""" + try: + from server import PromptServer + + if hasattr(PromptServer, "instance") and PromptServer.instance: + PromptServer.instance.send_sync(event_type, data) + except Exception: + pass + + def _update_progress( + self, + scanned: int | None = None, + total: int | None = None, + created: int | None = None, + skipped: int | None = None, + ) -> None: + """Update progress counters (thread-safe).""" + callback: ProgressCallback | None = None + progress: Progress | None = None + + with self._lock: + if self._progress is None: + return + if scanned is not None: + self._progress.scanned = scanned + if total is not None: + self._progress.total = total + if created is not None: + self._progress.created = created + if skipped is not None: + self._progress.skipped = skipped + if self._progress_callback: + callback = self._progress_callback + progress = Progress( + scanned=self._progress.scanned, + total=self._progress.total, + created=self._progress.created, + skipped=self._progress.skipped, + ) + + if callback and progress: + try: + callback(progress) + except Exception: + pass + + _MAX_ERRORS = 200 + + def _add_error(self, message: str) -> None: + """Add an error message (thread-safe), capped at _MAX_ERRORS.""" + with self._lock: + if len(self._errors) < self._MAX_ERRORS: + self._errors.append(message) + + def _log_scan_config(self, roots: tuple[RootType, ...]) -> None: + """Log the directories that will be scanned.""" + import folder_paths + + for root in roots: + if root == "models": + logging.info( + "Asset scan [models] directory: %s", + os.path.abspath(folder_paths.models_dir), + ) + else: + prefixes = get_prefixes_for_root(root) + if prefixes: + logging.info("Asset scan [%s] directories: %s", root, prefixes) + + def _run_scan(self) -> None: + """Main scan loop running in background thread.""" + t_start = time.perf_counter() + roots = self._roots + phase = self._phase + cancelled = False + total_created = 0 + total_enriched = 0 + skipped_existing = 0 + total_paths = 0 + + try: + if not dependencies_available(): + self._add_error("Database dependencies not available") + self._emit_event( + "assets.seed.error", + {"message": "Database dependencies not available"}, + ) + return + + if self._prune_first: + all_prefixes = get_all_known_prefixes() + marked = mark_missing_outside_prefixes_safely(all_prefixes) + if marked > 0: + logging.info("Marked %d refs as missing before scan", marked) + + if self._check_pause_and_cancel(): + logging.info("Asset scan cancelled after pruning phase") + cancelled = True + return + + self._log_scan_config(roots) + + # Phase 1: Fast scan (stub records) + if phase in (ScanPhase.FAST, ScanPhase.FULL): + created, skipped, paths = self._run_fast_phase(roots) + total_created, skipped_existing, total_paths = created, skipped, paths + + if self._check_pause_and_cancel(): + cancelled = True + return + + self._emit_event( + "assets.seed.fast_complete", + { + "roots": list(roots), + "created": total_created, + "skipped": skipped_existing, + "total": total_paths, + }, + ) + + # Phase 2: Enrichment scan (metadata + hashes) + if phase in (ScanPhase.ENRICH, ScanPhase.FULL): + if self._check_pause_and_cancel(): + cancelled = True + return + + enrich_cancelled, total_enriched = self._run_enrich_phase(roots) + + if enrich_cancelled: + cancelled = True + return + + self._emit_event( + "assets.seed.enrich_complete", + { + "roots": list(roots), + "enriched": total_enriched, + }, + ) + + elapsed = time.perf_counter() - t_start + logging.info( + "Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d", + roots, + phase.value, + elapsed, + total_created, + total_enriched, + skipped_existing, + ) + + self._emit_event( + "assets.seed.completed", + { + "phase": phase.value, + "total": total_paths, + "created": total_created, + "enriched": total_enriched, + "skipped": skipped_existing, + "elapsed": round(elapsed, 3), + }, + ) + + except Exception as e: + self._add_error(f"Scan failed: {e}") + logging.exception("Asset scan failed") + self._emit_event("assets.seed.error", {"message": str(e)}) + finally: + if cancelled: + self._emit_event( + "assets.seed.cancelled", + { + "scanned": self._progress.scanned if self._progress else 0, + "total": total_paths, + "created": total_created, + }, + ) + with self._lock: + self._last_progress = self._progress + self._state = State.IDLE + self._progress = None + + def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]: + """Run phase 1: fast scan to create stub records. + + Returns: + Tuple of (total_created, skipped_existing, total_paths) + """ + t_fast_start = time.perf_counter() + total_created = 0 + skipped_existing = 0 + + existing_paths: set[str] = set() + t_sync = time.perf_counter() + for r in roots: + if self._check_pause_and_cancel(): + return total_created, skipped_existing, 0 + existing_paths.update(sync_root_safely(r)) + logging.debug( + "Fast scan: sync_root phase took %.3fs (%d existing paths)", + time.perf_counter() - t_sync, + len(existing_paths), + ) + + if self._check_pause_and_cancel(): + return total_created, skipped_existing, 0 + + t_collect = time.perf_counter() + paths = collect_paths_for_roots(roots) + logging.debug( + "Fast scan: collect_paths took %.3fs (%d paths found)", + time.perf_counter() - t_collect, + len(paths), + ) + total_paths = len(paths) + self._update_progress(total=total_paths) + + self._emit_event( + "assets.seed.started", + {"roots": list(roots), "total": total_paths, "phase": "fast"}, + ) + + # Use stub specs (no metadata extraction, no hashing) + t_specs = time.perf_counter() + specs, tag_pool, skipped_existing = build_asset_specs( + paths, + existing_paths, + enable_metadata_extraction=False, + compute_hashes=False, + ) + logging.debug( + "Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)", + time.perf_counter() - t_specs, + len(specs), + skipped_existing, + ) + self._update_progress(skipped=skipped_existing) + + if self._check_pause_and_cancel(): + return total_created, skipped_existing, total_paths + + batch_size = 500 + last_progress_time = time.perf_counter() + progress_interval = 1.0 + + for i in range(0, len(specs), batch_size): + if self._check_pause_and_cancel(): + logging.info( + "Fast scan cancelled after %d/%d files (created=%d)", + i, + len(specs), + total_created, + ) + return total_created, skipped_existing, total_paths + + batch = specs[i : i + batch_size] + batch_tags = {t for spec in batch for t in spec["tags"]} + try: + created = insert_asset_specs(batch, batch_tags) + total_created += created + except Exception as e: + self._add_error(f"Batch insert failed at offset {i}: {e}") + logging.exception("Batch insert failed at offset %d", i) + + scanned = i + len(batch) + now = time.perf_counter() + self._update_progress(scanned=scanned, created=total_created) + + if now - last_progress_time >= progress_interval: + self._emit_event( + "assets.seed.progress", + { + "phase": "fast", + "scanned": scanned, + "total": len(specs), + "created": total_created, + }, + ) + last_progress_time = now + + self._update_progress(scanned=len(specs), created=total_created) + logging.info( + "Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)", + time.perf_counter() - t_fast_start, + total_created, + skipped_existing, + total_paths, + ) + return total_created, skipped_existing, total_paths + + def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]: + """Run phase 2: enrich existing records with metadata and hashes. + + Returns: + Tuple of (cancelled, total_enriched) + """ + total_enriched = 0 + batch_size = 100 + last_progress_time = time.perf_counter() + progress_interval = 1.0 + + # Get the target enrichment level based on compute_hashes + if not self._compute_hashes: + target_max_level = ENRICHMENT_STUB + else: + target_max_level = ENRICHMENT_METADATA + + self._emit_event( + "assets.seed.started", + {"roots": list(roots), "phase": "enrich"}, + ) + + skip_ids: set[str] = set() + consecutive_empty = 0 + max_consecutive_empty = 3 + + # Hash checkpoints survive across batches so interrupted hashes + # can be resumed without re-reading the entire file. + hash_checkpoints: dict[str, object] = {} + + while True: + if self._check_pause_and_cancel(): + logging.info("Enrich scan cancelled after %d assets", total_enriched) + return True, total_enriched + + # Fetch next batch of unenriched assets + unenriched = get_unenriched_assets_for_roots( + roots, + max_level=target_max_level, + limit=batch_size, + ) + + # Filter out previously failed references + if skip_ids: + unenriched = [r for r in unenriched if r.reference_id not in skip_ids] + + if not unenriched: + break + + enriched, failed_ids = enrich_assets_batch( + unenriched, + extract_metadata=True, + compute_hash=self._compute_hashes, + interrupt_check=self._is_paused_or_cancelled, + hash_checkpoints=hash_checkpoints, + ) + total_enriched += enriched + skip_ids.update(failed_ids) + + if enriched == 0: + consecutive_empty += 1 + if consecutive_empty >= max_consecutive_empty: + logging.warning( + "Enrich phase stopping: %d consecutive batches with no progress (%d skipped)", + consecutive_empty, + len(skip_ids), + ) + break + else: + consecutive_empty = 0 + + now = time.perf_counter() + if now - last_progress_time >= progress_interval: + self._emit_event( + "assets.seed.progress", + { + "phase": "enrich", + "enriched": total_enriched, + }, + ) + last_progress_time = now + + return False, total_enriched + + +asset_seeder = _AssetSeeder() diff --git a/app/assets/services/__init__.py b/app/assets/services/__init__.py new file mode 100644 index 000000000..11fcb4122 --- /dev/null +++ b/app/assets/services/__init__.py @@ -0,0 +1,87 @@ +from app.assets.services.asset_management import ( + asset_exists, + delete_asset_reference, + get_asset_by_hash, + get_asset_detail, + list_assets_page, + resolve_asset_for_download, + set_asset_preview, + update_asset_metadata, +) +from app.assets.services.bulk_ingest import ( + BulkInsertResult, + batch_insert_seed_assets, + cleanup_unreferenced_assets, +) +from app.assets.services.file_utils import ( + get_mtime_ns, + get_size_and_mtime_ns, + list_files_recursively, + verify_file_unchanged, +) +from app.assets.services.ingest import ( + DependencyMissingError, + HashMismatchError, + create_from_hash, + upload_from_temp_path, +) +from app.assets.database.queries import ( + AddTagsResult, + RemoveTagsResult, +) +from app.assets.services.schemas import ( + AssetData, + AssetDetailResult, + AssetSummaryData, + DownloadResolutionResult, + IngestResult, + ListAssetsResult, + ReferenceData, + RegisterAssetResult, + TagUsage, + UploadResult, + UserMetadata, +) +from app.assets.services.tagging import ( + apply_tags, + list_tags, + remove_tags, +) + +__all__ = [ + "AddTagsResult", + "AssetData", + "AssetDetailResult", + "AssetSummaryData", + "ReferenceData", + "BulkInsertResult", + "DependencyMissingError", + "DownloadResolutionResult", + "HashMismatchError", + "IngestResult", + "ListAssetsResult", + "RegisterAssetResult", + "RemoveTagsResult", + "TagUsage", + "UploadResult", + "UserMetadata", + "apply_tags", + "asset_exists", + "batch_insert_seed_assets", + "create_from_hash", + "delete_asset_reference", + "get_asset_by_hash", + "get_asset_detail", + "get_mtime_ns", + "get_size_and_mtime_ns", + "list_assets_page", + "list_files_recursively", + "list_tags", + "cleanup_unreferenced_assets", + "remove_tags", + "resolve_asset_for_download", + "set_asset_preview", + "update_asset_metadata", + "upload_from_temp_path", + "verify_file_unchanged", +] diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py new file mode 100644 index 000000000..3fe7115c8 --- /dev/null +++ b/app/assets/services/asset_management.py @@ -0,0 +1,309 @@ +import contextlib +import mimetypes +import os +from typing import Sequence + + +from app.assets.database.models import Asset +from app.assets.database.queries import ( + asset_exists_by_hash, + reference_exists_for_asset_id, + delete_reference_by_id, + fetch_reference_and_asset, + soft_delete_reference_by_id, + fetch_reference_asset_and_tags, + get_asset_by_hash as queries_get_asset_by_hash, + get_reference_by_id, + get_reference_with_owner_check, + list_references_page, + list_references_by_asset_id, + set_reference_metadata, + set_reference_preview, + set_reference_tags, + update_reference_access_time, + update_reference_name, + update_reference_updated_at, +) +from app.assets.helpers import select_best_live_path +from app.assets.services.path_utils import compute_relative_filename +from app.assets.services.schemas import ( + AssetData, + AssetDetailResult, + AssetSummaryData, + DownloadResolutionResult, + ListAssetsResult, + UserMetadata, + extract_asset_data, + extract_reference_data, +) +from app.database.db import create_session + + +def get_asset_detail( + reference_id: str, + owner_id: str = "", +) -> AssetDetailResult | None: + with create_session() as session: + result = fetch_reference_asset_and_tags( + session, + reference_id=reference_id, + owner_id=owner_id, + ) + if not result: + return None + + ref, asset, tags = result + return AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tags, + ) + + +def update_asset_metadata( + reference_id: str, + name: str | None = None, + tags: Sequence[str] | None = None, + user_metadata: UserMetadata = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> AssetDetailResult: + with create_session() as session: + ref = get_reference_with_owner_check(session, reference_id, owner_id) + + touched = False + if name is not None and name != ref.name: + update_reference_name(session, reference_id=reference_id, name=name) + touched = True + + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None + + new_meta: dict | None = None + if user_metadata is not None: + new_meta = dict(user_metadata) + elif computed_filename: + current_meta = ref.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + + if new_meta is not None: + if computed_filename: + new_meta["filename"] = computed_filename + set_reference_metadata( + session, reference_id=reference_id, user_metadata=new_meta + ) + touched = True + + if tags is not None: + set_reference_tags( + session, + reference_id=reference_id, + tags=tags, + origin=tag_origin, + ) + touched = True + + if touched and user_metadata is None: + update_reference_updated_at(session, reference_id=reference_id) + + result = fetch_reference_asset_and_tags( + session, + reference_id=reference_id, + owner_id=owner_id, + ) + if not result: + raise RuntimeError("State changed during update") + + ref, asset, tag_list = result + detail = AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_list, + ) + session.commit() + + return detail + + +def delete_asset_reference( + reference_id: str, + owner_id: str, + delete_content_if_orphan: bool = True, +) -> bool: + with create_session() as session: + if not delete_content_if_orphan: + # Soft delete: mark the reference as deleted but keep everything + deleted = soft_delete_reference_by_id( + session, reference_id=reference_id, owner_id=owner_id + ) + session.commit() + return deleted + + ref_row = get_reference_by_id(session, reference_id=reference_id) + asset_id = ref_row.asset_id if ref_row else None + file_path = ref_row.file_path if ref_row else None + + deleted = delete_reference_by_id( + session, reference_id=reference_id, owner_id=owner_id + ) + if not deleted: + session.commit() + return False + + if not asset_id: + session.commit() + return True + + still_exists = reference_exists_for_asset_id(session, asset_id=asset_id) + if still_exists: + session.commit() + return True + + # Orphaned asset - delete it and its files + refs = list_references_by_asset_id(session, asset_id=asset_id) + file_paths = [ + r.file_path for r in (refs or []) if getattr(r, "file_path", None) + ] + # Also include the just-deleted file path + if file_path: + file_paths.append(file_path) + + asset_row = session.get(Asset, asset_id) + if asset_row is not None: + session.delete(asset_row) + + session.commit() + + # Delete files after commit + for p in file_paths: + with contextlib.suppress(Exception): + if p and os.path.isfile(p): + os.remove(p) + + return True + + +def set_asset_preview( + reference_id: str, + preview_asset_id: str | None = None, + owner_id: str = "", +) -> AssetDetailResult: + with create_session() as session: + get_reference_with_owner_check(session, reference_id, owner_id) + + set_reference_preview( + session, + reference_id=reference_id, + preview_asset_id=preview_asset_id, + ) + + result = fetch_reference_asset_and_tags( + session, reference_id=reference_id, owner_id=owner_id + ) + if not result: + raise RuntimeError("State changed during preview update") + + ref, asset, tags = result + detail = AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tags, + ) + session.commit() + + return detail + + +def asset_exists(asset_hash: str) -> bool: + with create_session() as session: + return asset_exists_by_hash(session, asset_hash=asset_hash) + + +def get_asset_by_hash(asset_hash: str) -> AssetData | None: + with create_session() as session: + asset = queries_get_asset_by_hash(session, asset_hash=asset_hash) + return extract_asset_data(asset) + + +def list_assets_page( + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> ListAssetsResult: + with create_session() as session: + refs, tag_map, total = list_references_page( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + + items: list[AssetSummaryData] = [] + for ref in refs: + items.append( + AssetSummaryData( + ref=extract_reference_data(ref), + asset=extract_asset_data(ref.asset), + tags=tag_map.get(ref.id, []), + ) + ) + + return ListAssetsResult(items=items, total=total) + + +def resolve_asset_for_download( + reference_id: str, + owner_id: str = "", +) -> DownloadResolutionResult: + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise ValueError(f"AssetReference {reference_id} not found") + + ref, asset = pair + + # For references with file_path, use that directly + if ref.file_path and os.path.isfile(ref.file_path): + abs_path = ref.file_path + else: + # For API-created refs without file_path, find a path from other refs + refs = list_references_by_asset_id(session, asset_id=asset.id) + abs_path = select_best_live_path(refs) + if not abs_path: + raise FileNotFoundError( + f"No live path for AssetReference {reference_id} " + f"(asset id={asset.id}, name={ref.name})" + ) + + # Capture ORM attributes before commit (commit expires loaded objects) + ref_name = ref.name + asset_mime = asset.mime_type + + update_reference_access_time(session, reference_id=reference_id) + session.commit() + + ctype = ( + asset_mime + or mimetypes.guess_type(ref_name or abs_path)[0] + or "application/octet-stream" + ) + download_name = ref_name or os.path.basename(abs_path) + return DownloadResolutionResult( + abs_path=abs_path, + content_type=ctype, + download_name=download_name, + ) diff --git a/app/assets/services/bulk_ingest.py b/app/assets/services/bulk_ingest.py new file mode 100644 index 000000000..54e72730c --- /dev/null +++ b/app/assets/services/bulk_ingest.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import os +import uuid +from dataclasses import dataclass +from datetime import datetime +from typing import TYPE_CHECKING, Any, TypedDict + +from sqlalchemy.orm import Session + +from app.assets.database.queries import ( + bulk_insert_assets, + bulk_insert_references_ignore_conflicts, + bulk_insert_tags_and_meta, + delete_assets_by_ids, + get_existing_asset_ids, + get_reference_ids_by_ids, + get_references_by_paths_and_asset_ids, + get_unreferenced_unhashed_asset_ids, + restore_references_by_paths, +) +from app.assets.helpers import get_utc_now + +if TYPE_CHECKING: + from app.assets.services.metadata_extract import ExtractedMetadata + + +class SeedAssetSpec(TypedDict): + """Spec for seeding an asset from filesystem.""" + + abs_path: str + size_bytes: int + mtime_ns: int + info_name: str + tags: list[str] + fname: str + metadata: ExtractedMetadata | None + hash: str | None + mime_type: str | None + + +class AssetRow(TypedDict): + """Row data for inserting an Asset.""" + + id: str + hash: str | None + size_bytes: int + mime_type: str | None + created_at: datetime + + +class ReferenceRow(TypedDict): + """Row data for inserting an AssetReference.""" + + id: str + asset_id: str + file_path: str + mtime_ns: int + owner_id: str + name: str + preview_id: str | None + user_metadata: dict[str, Any] | None + created_at: datetime + updated_at: datetime + last_access_time: datetime + + +class TagRow(TypedDict): + """Row data for inserting a Tag.""" + + asset_reference_id: str + tag_name: str + origin: str + added_at: datetime + + +class MetadataRow(TypedDict): + """Row data for inserting asset metadata.""" + + asset_reference_id: str + key: str + ordinal: int + val_str: str | None + val_num: float | None + val_bool: bool | None + val_json: dict[str, Any] | None + + +@dataclass +class BulkInsertResult: + """Result of bulk asset insertion.""" + + inserted_refs: int + won_paths: int + lost_paths: int + + +def batch_insert_seed_assets( + session: Session, + specs: list[SeedAssetSpec], + owner_id: str = "", +) -> BulkInsertResult: + """Seed assets from filesystem specs in batch. + + Each spec is a dict with keys: + - abs_path: str + - size_bytes: int + - mtime_ns: int + - info_name: str + - tags: list[str] + - fname: Optional[str] + + This function orchestrates: + 1. Insert seed Assets (hash=NULL) + 2. Claim references with ON CONFLICT DO NOTHING on file_path + 3. Query to find winners (paths where our asset_id was inserted) + 4. Delete Assets for losers (path already claimed by another asset) + 5. Insert tags and metadata for successfully inserted references + + Returns: + BulkInsertResult with inserted_refs, won_paths, lost_paths + """ + if not specs: + return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0) + + current_time = get_utc_now() + asset_rows: list[AssetRow] = [] + reference_rows: list[ReferenceRow] = [] + path_to_asset_id: dict[str, str] = {} + asset_id_to_ref_data: dict[str, dict] = {} + absolute_path_list: list[str] = [] + + for spec in specs: + absolute_path = os.path.abspath(spec["abs_path"]) + asset_id = str(uuid.uuid4()) + reference_id = str(uuid.uuid4()) + absolute_path_list.append(absolute_path) + path_to_asset_id[absolute_path] = asset_id + + mime_type = spec.get("mime_type") + asset_rows.append( + { + "id": asset_id, + "hash": spec.get("hash"), + "size_bytes": spec["size_bytes"], + "mime_type": mime_type, + "created_at": current_time, + } + ) + + # Build user_metadata from extracted metadata or fallback to filename + extracted_metadata = spec.get("metadata") + if extracted_metadata: + user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata() + elif spec["fname"]: + user_metadata = {"filename": spec["fname"]} + else: + user_metadata = None + + reference_rows.append( + { + "id": reference_id, + "asset_id": asset_id, + "file_path": absolute_path, + "mtime_ns": spec["mtime_ns"], + "owner_id": owner_id, + "name": spec["info_name"], + "preview_id": None, + "user_metadata": user_metadata, + "created_at": current_time, + "updated_at": current_time, + "last_access_time": current_time, + } + ) + + asset_id_to_ref_data[asset_id] = { + "reference_id": reference_id, + "tags": spec["tags"], + "filename": spec["fname"], + "extracted_metadata": extracted_metadata, + } + + bulk_insert_assets(session, asset_rows) + + # Filter reference rows to only those whose assets were actually inserted + # (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING) + inserted_asset_ids = get_existing_asset_ids( + session, [r["asset_id"] for r in reference_rows] + ) + reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids] + + bulk_insert_references_ignore_conflicts(session, reference_rows) + restore_references_by_paths(session, absolute_path_list) + winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id) + + inserted_paths = { + path + for path in absolute_path_list + if path_to_asset_id[path] in inserted_asset_ids + } + losing_paths = inserted_paths - winning_paths + lost_asset_ids = [path_to_asset_id[path] for path in losing_paths] + + if lost_asset_ids: + delete_assets_by_ids(session, lost_asset_ids) + + if not winning_paths: + return BulkInsertResult( + inserted_refs=0, + won_paths=0, + lost_paths=len(losing_paths), + ) + + # Get reference IDs for winners + winning_ref_ids = [ + asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"] + for path in winning_paths + ] + inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids) + + tag_rows: list[TagRow] = [] + metadata_rows: list[MetadataRow] = [] + + if inserted_ref_ids: + for path in winning_paths: + asset_id = path_to_asset_id[path] + ref_data = asset_id_to_ref_data[asset_id] + ref_id = ref_data["reference_id"] + + if ref_id not in inserted_ref_ids: + continue + + for tag in ref_data["tags"]: + tag_rows.append( + { + "asset_reference_id": ref_id, + "tag_name": tag, + "origin": "automatic", + "added_at": current_time, + } + ) + + # Use extracted metadata for meta rows if available + extracted_metadata = ref_data.get("extracted_metadata") + if extracted_metadata: + metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id)) + elif ref_data["filename"]: + # Fallback: just store filename + metadata_rows.append( + { + "asset_reference_id": ref_id, + "key": "filename", + "ordinal": 0, + "val_str": ref_data["filename"], + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) + + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows) + + return BulkInsertResult( + inserted_refs=len(inserted_ref_ids), + won_paths=len(winning_paths), + lost_paths=len(losing_paths), + ) + + +def cleanup_unreferenced_assets(session: Session) -> int: + """Hard-delete unhashed assets with no active references. + + This is a destructive operation intended for explicit cleanup. + Only deletes assets where hash=None and all references are missing. + + Returns: + Number of assets deleted + """ + unreferenced_ids = get_unreferenced_unhashed_asset_ids(session) + return delete_assets_by_ids(session, unreferenced_ids) diff --git a/app/assets/services/file_utils.py b/app/assets/services/file_utils.py new file mode 100644 index 000000000..c47ebe460 --- /dev/null +++ b/app/assets/services/file_utils.py @@ -0,0 +1,70 @@ +import os + + +def get_mtime_ns(stat_result: os.stat_result) -> int: + """Extract mtime in nanoseconds from a stat result.""" + return getattr( + stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000) + ) + + +def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]: + """Get file size in bytes and mtime in nanoseconds.""" + st = os.stat(path, follow_symlinks=follow_symlinks) + return st.st_size, get_mtime_ns(st) + + +def verify_file_unchanged( + mtime_db: int | None, + size_db: int | None, + stat_result: os.stat_result, +) -> bool: + """Check if a file is unchanged based on mtime and size. + + Returns True if the file's mtime and size match the database values. + Returns False if mtime_db is None or values don't match. + + size_db=None means don't check size; 0 is a valid recorded size. + """ + if mtime_db is None: + return False + actual_mtime_ns = get_mtime_ns(stat_result) + if int(mtime_db) != int(actual_mtime_ns): + return False + if size_db is not None: + return int(stat_result.st_size) == int(size_db) + return True + + +def is_visible(name: str) -> bool: + """Return True if a file or directory name is visible (not hidden).""" + return not name.startswith(".") + + +def list_files_recursively(base_dir: str) -> list[str]: + """Recursively list all files in a directory, following symlinks.""" + out: list[str] = [] + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + return out + # Track seen real directory identities to prevent circular symlink loops + seen_dirs: set[tuple[int, int]] = set() + for dirpath, subdirs, filenames in os.walk( + base_abs, topdown=True, followlinks=True + ): + try: + st = os.stat(dirpath) + dir_id = (st.st_dev, st.st_ino) + except OSError: + subdirs.clear() + continue + if dir_id in seen_dirs: + subdirs.clear() + continue + seen_dirs.add(dir_id) + subdirs[:] = [d for d in subdirs if is_visible(d)] + for name in filenames: + if not is_visible(name): + continue + out.append(os.path.abspath(os.path.join(dirpath, name))) + return out diff --git a/app/assets/services/hashing.py b/app/assets/services/hashing.py new file mode 100644 index 000000000..92aee6402 --- /dev/null +++ b/app/assets/services/hashing.py @@ -0,0 +1,95 @@ +import io +import os +from contextlib import contextmanager +from dataclasses import dataclass +from typing import IO, Any, Callable, Iterator + +from blake3 import blake3 + +DEFAULT_CHUNK = 8 * 1024 * 1024 + +InterruptCheck = Callable[[], bool] + + +@dataclass +class HashCheckpoint: + """Saved state for resuming an interrupted hash computation.""" + + bytes_processed: int + hasher: Any # blake3 hasher instance + mtime_ns: int = 0 + file_size: int = 0 + + +@contextmanager +def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]: + """Yield (file_object, is_path) with appropriate setup/teardown.""" + if hasattr(fp, "read"): + seekable = getattr(fp, "seekable", lambda: False)() + orig_pos = None + if seekable: + try: + orig_pos = fp.tell() + if orig_pos != 0: + fp.seek(0) + except io.UnsupportedOperation: + orig_pos = None + try: + yield fp, False + finally: + if orig_pos is not None: + fp.seek(orig_pos) + else: + with open(os.fspath(fp), "rb") as f: + yield f, True + + +def compute_blake3_hash( + fp: str | IO[bytes], + chunk_size: int = DEFAULT_CHUNK, + interrupt_check: InterruptCheck | None = None, + checkpoint: HashCheckpoint | None = None, +) -> tuple[str | None, HashCheckpoint | None]: + """Compute BLAKE3 hash of a file, with optional checkpoint support. + + Args: + fp: File path or file-like object + chunk_size: Size of chunks to read at a time + interrupt_check: Optional callable that returns True if the operation + should be interrupted (e.g. paused or cancelled). Must be + non-blocking so file handles are released immediately. Checked + between chunk reads. + checkpoint: Optional checkpoint to resume from (file paths only) + + Returns: + Tuple of (hex_digest, None) on completion, or + (None, checkpoint) on interruption (file paths only), or + (None, None) on interruption of a file object + """ + if chunk_size <= 0: + chunk_size = DEFAULT_CHUNK + + with _open_for_hashing(fp) as (f, is_path): + if checkpoint is not None and is_path: + f.seek(checkpoint.bytes_processed) + h = checkpoint.hasher + bytes_processed = checkpoint.bytes_processed + else: + h = blake3() + bytes_processed = 0 + + while True: + if interrupt_check is not None and interrupt_check(): + if is_path: + return None, HashCheckpoint( + bytes_processed=bytes_processed, + hasher=h, + ) + return None, None + chunk = f.read(chunk_size) + if not chunk: + break + h.update(chunk) + bytes_processed += len(chunk) + + return h.hexdigest(), None diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py new file mode 100644 index 000000000..44d7aef36 --- /dev/null +++ b/app/assets/services/ingest.py @@ -0,0 +1,375 @@ +import contextlib +import logging +import mimetypes +import os +from typing import Any, Sequence + +from sqlalchemy.orm import Session + +import app.assets.services.hashing as hashing +from app.assets.database.queries import ( + add_tags_to_reference, + fetch_reference_and_asset, + get_asset_by_hash, + get_existing_asset_ids, + get_reference_by_file_path, + get_reference_tags, + get_or_create_reference, + remove_missing_tag_for_asset_id, + set_reference_metadata, + set_reference_tags, + upsert_asset, + upsert_reference, + validate_tags_exist, +) +from app.assets.helpers import normalize_tags +from app.assets.services.file_utils import get_size_and_mtime_ns +from app.assets.services.path_utils import ( + compute_relative_filename, + resolve_destination_from_tags, + validate_path_within_base, +) +from app.assets.services.schemas import ( + IngestResult, + RegisterAssetResult, + UploadResult, + UserMetadata, + extract_asset_data, + extract_reference_data, +) +from app.database.db import create_session + + +def _ingest_file_from_path( + abs_path: str, + asset_hash: str, + size_bytes: int, + mtime_ns: int, + mime_type: str | None = None, + info_name: str | None = None, + owner_id: str = "", + preview_id: str | None = None, + user_metadata: UserMetadata = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + require_existing_tags: bool = False, +) -> IngestResult: + locator = os.path.abspath(abs_path) + user_metadata = user_metadata or {} + + asset_created = False + asset_updated = False + ref_created = False + ref_updated = False + reference_id: str | None = None + + with create_session() as session: + if preview_id: + if preview_id not in get_existing_asset_ids(session, [preview_id]): + preview_id = None + + asset, asset_created, asset_updated = upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=size_bytes, + mime_type=mime_type, + ) + + ref_created, ref_updated = upsert_reference( + session, + asset_id=asset.id, + file_path=locator, + name=info_name or os.path.basename(locator), + mtime_ns=mtime_ns, + owner_id=owner_id, + ) + + # Get the reference we just created/updated + ref = get_reference_by_file_path(session, locator) + if ref: + reference_id = ref.id + + if preview_id and ref.preview_id != preview_id: + ref.preview_id = preview_id + + norm = normalize_tags(list(tags)) + if norm: + if require_existing_tags: + validate_tags_exist(session, norm) + add_tags_to_reference( + session, + reference_id=reference_id, + tags=norm, + origin=tag_origin, + create_if_missing=not require_existing_tags, + ) + + _update_metadata_with_filename( + session, + reference_id=reference_id, + file_path=ref.file_path, + current_metadata=ref.user_metadata, + user_metadata=user_metadata, + ) + + try: + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + except Exception: + logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) + + session.commit() + + return IngestResult( + asset_created=asset_created, + asset_updated=asset_updated, + ref_created=ref_created, + ref_updated=ref_updated, + reference_id=reference_id, + ) + + +def _register_existing_asset( + asset_hash: str, + name: str, + user_metadata: UserMetadata = None, + tags: list[str] | None = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> RegisterAssetResult: + user_metadata = user_metadata or {} + + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=asset_hash) + if not asset: + raise ValueError(f"No asset with hash {asset_hash}") + + ref, ref_created = get_or_create_reference( + session, + asset_id=asset.id, + owner_id=owner_id, + name=name, + ) + + if not ref_created: + tag_names = get_reference_tags(session, reference_id=ref.id) + result = RegisterAssetResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created=False, + ) + session.commit() + return result + + new_meta = dict(user_metadata) + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta: + set_reference_metadata( + session, + reference_id=ref.id, + user_metadata=new_meta, + ) + + if tags is not None: + set_reference_tags( + session, + reference_id=ref.id, + tags=tags, + origin=tag_origin, + ) + + tag_names = get_reference_tags(session, reference_id=ref.id) + session.refresh(ref) + result = RegisterAssetResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created=True, + ) + session.commit() + + return result + + + +def _update_metadata_with_filename( + session: Session, + reference_id: str, + file_path: str | None, + current_metadata: dict | None, + user_metadata: dict[str, Any], +) -> None: + computed_filename = compute_relative_filename(file_path) if file_path else None + + current_meta = current_metadata or {} + new_meta = dict(current_meta) + for k, v in user_metadata.items(): + new_meta[k] = v + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta != current_meta: + set_reference_metadata( + session, + reference_id=reference_id, + user_metadata=new_meta, + ) + + +def _sanitize_filename(name: str | None, fallback: str) -> str: + n = os.path.basename((name or "").strip() or fallback) + return n if n else fallback + + +class HashMismatchError(Exception): + pass + + +class DependencyMissingError(Exception): + def __init__(self, message: str): + self.message = message + super().__init__(message) + + +def upload_from_temp_path( + temp_path: str, + name: str | None = None, + tags: list[str] | None = None, + user_metadata: dict | None = None, + client_filename: str | None = None, + owner_id: str = "", + expected_hash: str | None = None, +) -> UploadResult: + try: + digest, _ = hashing.compute_blake3_hash(temp_path) + except ImportError as e: + raise DependencyMissingError(str(e)) + except Exception as e: + raise RuntimeError(f"failed to hash uploaded file: {e}") + asset_hash = "blake3:" + digest + + if expected_hash and asset_hash != expected_hash.strip().lower(): + raise HashMismatchError("Uploaded file hash does not match provided hash.") + + with create_session() as session: + existing = get_asset_by_hash(session, asset_hash=asset_hash) + + if existing is not None: + with contextlib.suppress(Exception): + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + display_name = _sanitize_filename(name or client_filename, fallback=digest) + result = _register_existing_asset( + asset_hash=asset_hash, + name=display_name, + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + return UploadResult( + ref=result.ref, + asset=result.asset, + tags=result.tags, + created_new=False, + ) + + if not tags: + raise ValueError("tags are required for new asset uploads") + base_dir, subdirs = resolve_destination_from_tags(tags) + dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir + os.makedirs(dest_dir, exist_ok=True) + + src_for_ext = (client_filename or name or "").strip() + _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" + ext = _ext if 0 < len(_ext) <= 16 else "" + hashed_basename = f"{digest}{ext}" + dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) + validate_path_within_base(dest_abs, base_dir) + + content_type = ( + mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] + or mimetypes.guess_type(hashed_basename, strict=False)[0] + or "application/octet-stream" + ) + + try: + os.replace(temp_path, dest_abs) + except Exception as e: + raise RuntimeError(f"failed to move uploaded file into place: {e}") + + try: + size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs) + except OSError as e: + raise RuntimeError(f"failed to stat destination file: {e}") + + ingest_result = _ingest_file_from_path( + asset_hash=asset_hash, + abs_path=dest_abs, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=_sanitize_filename(name or client_filename, fallback=digest), + owner_id=owner_id, + preview_id=None, + user_metadata=user_metadata or {}, + tags=tags, + tag_origin="manual", + require_existing_tags=False, + ) + reference_id = ingest_result.reference_id + if not reference_id: + raise RuntimeError("failed to create asset reference") + + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + ref, asset = pair + tag_names = get_reference_tags(session, reference_id=ref.id) + + return UploadResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created_new=ingest_result.asset_created, + ) + + +def create_from_hash( + hash_str: str, + name: str, + tags: list[str] | None = None, + user_metadata: dict | None = None, + owner_id: str = "", +) -> UploadResult | None: + canonical = hash_str.strip().lower() + + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=canonical) + if not asset: + return None + + result = _register_existing_asset( + asset_hash=canonical, + name=_sanitize_filename( + name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical + ), + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + + return UploadResult( + ref=result.ref, + asset=result.asset, + tags=result.tags, + created_new=False, + ) diff --git a/app/assets/services/metadata_extract.py b/app/assets/services/metadata_extract.py new file mode 100644 index 000000000..a004929bc --- /dev/null +++ b/app/assets/services/metadata_extract.py @@ -0,0 +1,327 @@ +"""Metadata extraction for asset scanning. + +Tier 1: Filesystem metadata (zero parsing) +Tier 2: Safetensors header metadata (fast JSON read only) +""" + +from __future__ import annotations + +import json +import logging +import mimetypes +import os +import struct +from dataclasses import dataclass +from typing import Any + +from utils.mime_types import init_mime_types + +init_mime_types() + +# Supported safetensors extensions +SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"}) + +# Maximum safetensors header size to read (8MB) +MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024 + + +@dataclass +class ExtractedMetadata: + """Metadata extracted from a file during scanning.""" + + # Tier 1: Filesystem (always available) + filename: str = "" + file_path: str = "" # Full absolute path to the file + content_length: int = 0 + content_type: str | None = None + format: str = "" # file extension without dot + + # Tier 2: Safetensors header (if available) + base_model: str | None = None + trained_words: list[str] | None = None + air: str | None = None # CivitAI AIR identifier + has_preview_images: bool = False + + # Source provenance (populated if embedded in safetensors) + source_url: str | None = None + source_arn: str | None = None + repo_url: str | None = None + preview_url: str | None = None + source_hash: str | None = None + + # HuggingFace specific + repo_id: str | None = None + revision: str | None = None + filepath: str | None = None + resolve_url: str | None = None + + def to_user_metadata(self) -> dict[str, Any]: + """Convert to user_metadata dict for AssetReference.user_metadata JSON field.""" + data: dict[str, Any] = { + "filename": self.filename, + "content_length": self.content_length, + "format": self.format, + } + if self.file_path: + data["file_path"] = self.file_path + if self.content_type: + data["content_type"] = self.content_type + + # Tier 2 fields + if self.base_model: + data["base_model"] = self.base_model + if self.trained_words: + data["trained_words"] = self.trained_words + if self.air: + data["air"] = self.air + if self.has_preview_images: + data["has_preview_images"] = True + + # Source provenance + if self.source_url: + data["source_url"] = self.source_url + if self.source_arn: + data["source_arn"] = self.source_arn + if self.repo_url: + data["repo_url"] = self.repo_url + if self.preview_url: + data["preview_url"] = self.preview_url + if self.source_hash: + data["source_hash"] = self.source_hash + + # HuggingFace + if self.repo_id: + data["repo_id"] = self.repo_id + if self.revision: + data["revision"] = self.revision + if self.filepath: + data["filepath"] = self.filepath + if self.resolve_url: + data["resolve_url"] = self.resolve_url + + return data + + def to_meta_rows(self, reference_id: str) -> list[dict]: + """Convert to asset_reference_meta rows for typed/indexed querying.""" + rows: list[dict] = [] + + def add_str(key: str, val: str | None, ordinal: int = 0) -> None: + if val: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": ordinal, + "val_str": val[:2048] if len(val) > 2048 else val, + "val_num": None, + "val_bool": None, + "val_json": None, + }) + + def add_num(key: str, val: int | float | None) -> None: + if val is not None: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": 0, + "val_str": None, + "val_num": val, + "val_bool": None, + "val_json": None, + }) + + def add_bool(key: str, val: bool | None) -> None: + if val is not None: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": 0, + "val_str": None, + "val_num": None, + "val_bool": val, + "val_json": None, + }) + + # Tier 1 + add_str("filename", self.filename) + add_num("content_length", self.content_length) + add_str("content_type", self.content_type) + add_str("format", self.format) + + # Tier 2 + add_str("base_model", self.base_model) + add_str("air", self.air) + has_previews = self.has_preview_images if self.has_preview_images else None + add_bool("has_preview_images", has_previews) + + # trained_words as multiple rows with ordinals + if self.trained_words: + for i, word in enumerate(self.trained_words[:100]): # limit to 100 words + add_str("trained_words", word, ordinal=i) + + # Source provenance + add_str("source_url", self.source_url) + add_str("source_arn", self.source_arn) + add_str("repo_url", self.repo_url) + add_str("preview_url", self.preview_url) + add_str("source_hash", self.source_hash) + + # HuggingFace + add_str("repo_id", self.repo_id) + add_str("revision", self.revision) + add_str("filepath", self.filepath) + add_str("resolve_url", self.resolve_url) + + return rows + + +def _read_safetensors_header( + path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE +) -> dict[str, Any] | None: + """Read only the JSON header from a safetensors file. + + This is very fast - reads 8 bytes for header length, then the JSON header. + No tensor data is loaded. + + Args: + path: Absolute path to safetensors file + max_size: Maximum header size to read (default 8MB) + + Returns: + Parsed header dict or None if failed + """ + try: + with open(path, "rb") as f: + header_bytes = f.read(8) + if len(header_bytes) < 8: + return None + length_of_header = struct.unpack(" max_size: + return None + header_data = f.read(length_of_header) + if len(header_data) < length_of_header: + return None + return json.loads(header_data.decode("utf-8")) + except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error): + return None + + +def _extract_safetensors_metadata( + header: dict[str, Any], meta: ExtractedMetadata +) -> None: + """Extract metadata from safetensors header __metadata__ section. + + Modifies meta in-place. + """ + st_meta = header.get("__metadata__", {}) + if not isinstance(st_meta, dict): + return + + # Common model metadata + meta.base_model = ( + st_meta.get("ss_base_model_version") + or st_meta.get("modelspec.base_model") + or st_meta.get("base_model") + ) + + # Trained words / trigger words + trained_words = st_meta.get("ss_tag_frequency") + if trained_words and isinstance(trained_words, str): + try: + tag_freq = json.loads(trained_words) + # Extract unique tags from all datasets + all_tags: set[str] = set() + for dataset_tags in tag_freq.values(): + if isinstance(dataset_tags, dict): + all_tags.update(dataset_tags.keys()) + if all_tags: + meta.trained_words = sorted(all_tags)[:100] + except json.JSONDecodeError: + pass + + # Direct trained_words field (some formats) + if not meta.trained_words: + tw = st_meta.get("trained_words") + if isinstance(tw, str): + try: + parsed = json.loads(tw) + if isinstance(parsed, list): + meta.trained_words = [str(x) for x in parsed] + else: + meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()] + except json.JSONDecodeError: + meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()] + elif isinstance(tw, list): + meta.trained_words = [str(x) for x in tw] + + # CivitAI AIR + meta.air = st_meta.get("air") or st_meta.get("modelspec.air") + + # Preview images (ssmd_cover_images) + cover_images = st_meta.get("ssmd_cover_images") + if cover_images: + meta.has_preview_images = True + + # Source provenance fields + meta.source_url = st_meta.get("source_url") + meta.source_arn = st_meta.get("source_arn") + meta.repo_url = st_meta.get("repo_url") + meta.preview_url = st_meta.get("preview_url") + meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash") + + # HuggingFace fields + meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id") + meta.revision = st_meta.get("revision") or st_meta.get("hf_revision") + meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath") + meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url") + + +def extract_file_metadata( + abs_path: str, + stat_result: os.stat_result | None = None, + relative_filename: str | None = None, +) -> ExtractedMetadata: + """Extract metadata from a file using tier 1 and tier 2 methods. + + Tier 1: Filesystem metadata from path and stat + Tier 2: Safetensors header parsing if applicable + + Args: + abs_path: Absolute path to the file + stat_result: Optional pre-fetched stat result (saves a syscall) + relative_filename: Optional relative filename to use instead of basename + (e.g., "flux/123/model.safetensors" for model paths) + + Returns: + ExtractedMetadata with all available fields populated + """ + meta = ExtractedMetadata() + + # Tier 1: Filesystem metadata + meta.filename = relative_filename or os.path.basename(abs_path) + meta.file_path = abs_path + _, ext = os.path.splitext(abs_path) + meta.format = ext.lstrip(".").lower() if ext else "" + + mime_type, _ = mimetypes.guess_type(abs_path) + meta.content_type = mime_type + + # Size from stat + if stat_result is None: + try: + stat_result = os.stat(abs_path, follow_symlinks=True) + except OSError: + pass + + if stat_result: + meta.content_length = stat_result.st_size + + # Tier 2: Safetensors header (if applicable and enabled) + if ext.lower() in SAFETENSORS_EXTENSIONS: + header = _read_safetensors_header(abs_path) + if header: + try: + _extract_safetensors_metadata(header, meta) + except Exception as e: + logging.debug("Safetensors meta extract failed %s: %s", abs_path, e) + + return meta diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py new file mode 100644 index 000000000..f5dd7f7fd --- /dev/null +++ b/app/assets/services/path_utils.py @@ -0,0 +1,167 @@ +import os +from pathlib import Path +from typing import Literal + +import folder_paths +from app.assets.helpers import normalize_tags + + +_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"}) + + +def get_comfy_models_folders() -> list[tuple[str, list[str]]]: + """Build list of (folder_name, base_paths[]) for all model locations. + + Includes every category registered in folder_names_and_paths, + regardless of whether its paths are under the main models_dir, + but excludes non-model entries like custom_nodes. + """ + targets: list[tuple[str, list[str]]] = [] + for name, values in folder_paths.folder_names_and_paths.items(): + if name in _NON_MODEL_FOLDER_NAMES: + continue + paths, _exts = values[0], values[1] + if paths: + targets.append((name, paths)) + return targets + + +def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: + """Validates and maps tags -> (base_dir, subdirs_for_fs)""" + if not tags: + raise ValueError("tags must not be empty") + root = tags[0].lower() + if root == "models": + if len(tags) < 2: + raise ValueError("at least two tags required for model asset") + try: + bases = folder_paths.folder_names_and_paths[tags[1]][0] + except KeyError: + raise ValueError(f"unknown model category '{tags[1]}'") + if not bases: + raise ValueError(f"no base path configured for category '{tags[1]}'") + base_dir = os.path.abspath(bases[0]) + raw_subdirs = tags[2:] + elif root == "input": + base_dir = os.path.abspath(folder_paths.get_input_directory()) + raw_subdirs = tags[1:] + elif root == "output": + base_dir = os.path.abspath(folder_paths.get_output_directory()) + raw_subdirs = tags[1:] + else: + raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'") + _sep_chars = frozenset(("/", "\\", os.sep)) + for i in raw_subdirs: + if i in (".", "..") or _sep_chars & set(i): + raise ValueError("invalid path component in tags") + + return base_dir, raw_subdirs if raw_subdirs else [] + + +def validate_path_within_base(candidate: str, base: str) -> None: + cand_abs = Path(os.path.abspath(candidate)) + base_abs = Path(os.path.abspath(base)) + if not cand_abs.is_relative_to(base_abs): + raise ValueError("destination escapes base directory") + + +def compute_relative_filename(file_path: str) -> str | None: + """ + Return the model's path relative to the last well-known folder (the model category), + using forward slashes, eg: + /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" + /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + + For non-model paths, returns None. + """ + try: + root_category, rel_path = get_asset_category_and_relative_path(file_path) + except ValueError: + return None + + p = Path(rel_path) + parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + if not parts: + return None + + if root_category == "models": + # parts[0] is the category ("checkpoints", "vae", etc) – drop it + inside = parts[1:] if len(parts) > 1 else [parts[0]] + return "/".join(inside) + return "/".join(parts) # input/output: keep all parts + + +def get_asset_category_and_relative_path( + file_path: str, +) -> tuple[Literal["input", "output", "models"], str]: + """Determine which root category a file path belongs to. + + Categories: + - 'input': under folder_paths.get_input_directory() + - 'output': under folder_paths.get_output_directory() + - 'models': under any base path from get_comfy_models_folders() + + Returns: + (root_category, relative_path_inside_that_root) + + Raises: + ValueError: path does not belong to any known root. + """ + fp_abs = os.path.abspath(file_path) + + def _check_is_within(child: str, parent: str) -> bool: + return Path(child).is_relative_to(parent) + + def _compute_relative(child: str, parent: str) -> str: + # Normalize relative path, stripping any leading ".." components + # by anchoring to root (os.sep) then computing relpath back from it. + return os.path.relpath( + os.path.join(os.sep, os.path.relpath(child, parent)), os.sep + ) + + # 1) input + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _check_is_within(fp_abs, input_base): + return "input", _compute_relative(fp_abs, input_base) + + # 2) output + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _check_is_within(fp_abs, output_base): + return "output", _compute_relative(fp_abs, output_base) + + # 3) models (check deepest matching base to avoid ambiguity) + best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) + for bucket, bases in get_comfy_models_folders(): + for b in bases: + base_abs = os.path.abspath(b) + if not _check_is_within(fp_abs, base_abs): + continue + cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs)) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, bucket, rel_inside = best + combined = os.path.join(bucket, rel_inside) + return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + + raise ValueError( + f"Path is not within input, output, or configured model bases: {file_path}" + ) + + +def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: + """Return (name, tags) derived from a filesystem path. + + - name: base filename with extension + - tags: [root_category] + parent folder names in order + + Raises: + ValueError: path does not belong to any known root. + """ + root_category, some_path = get_asset_category_and_relative_path(file_path) + p = Path(some_path) + parent_parts = [ + part for part in p.parent.parts if part not in (".", "..", p.anchor) + ] + return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py new file mode 100644 index 000000000..8b1f1f4dc --- /dev/null +++ b/app/assets/services/schemas.py @@ -0,0 +1,109 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Any, NamedTuple + +from app.assets.database.models import Asset, AssetReference + +UserMetadata = dict[str, Any] | None + + +@dataclass(frozen=True) +class AssetData: + hash: str | None + size_bytes: int | None + mime_type: str | None + + +@dataclass(frozen=True) +class ReferenceData: + """Data transfer object for AssetReference.""" + + id: str + name: str + file_path: str | None + user_metadata: UserMetadata + preview_id: str | None + created_at: datetime + updated_at: datetime + last_access_time: datetime | None + + +@dataclass(frozen=True) +class AssetDetailResult: + ref: ReferenceData + asset: AssetData | None + tags: list[str] + + +@dataclass(frozen=True) +class RegisterAssetResult: + ref: ReferenceData + asset: AssetData + tags: list[str] + created: bool + + +@dataclass(frozen=True) +class IngestResult: + asset_created: bool + asset_updated: bool + ref_created: bool + ref_updated: bool + reference_id: str | None + + +class TagUsage(NamedTuple): + name: str + tag_type: str + count: int + + +@dataclass(frozen=True) +class AssetSummaryData: + ref: ReferenceData + asset: AssetData | None + tags: list[str] + + +@dataclass(frozen=True) +class ListAssetsResult: + items: list[AssetSummaryData] + total: int + + +@dataclass(frozen=True) +class DownloadResolutionResult: + abs_path: str + content_type: str + download_name: str + + +@dataclass(frozen=True) +class UploadResult: + ref: ReferenceData + asset: AssetData + tags: list[str] + created_new: bool + + +def extract_reference_data(ref: AssetReference) -> ReferenceData: + return ReferenceData( + id=ref.id, + name=ref.name, + file_path=ref.file_path, + user_metadata=ref.user_metadata, + preview_id=ref.preview_id, + created_at=ref.created_at, + updated_at=ref.updated_at, + last_access_time=ref.last_access_time, + ) + + +def extract_asset_data(asset: Asset | None) -> AssetData | None: + if asset is None: + return None + return AssetData( + hash=asset.hash, + size_bytes=asset.size_bytes, + mime_type=asset.mime_type, + ) diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py new file mode 100644 index 000000000..28900464d --- /dev/null +++ b/app/assets/services/tagging.py @@ -0,0 +1,75 @@ +from app.assets.database.queries import ( + AddTagsResult, + RemoveTagsResult, + add_tags_to_reference, + get_reference_with_owner_check, + list_tags_with_usage, + remove_tags_from_reference, +) +from app.assets.services.schemas import TagUsage +from app.database.db import create_session + + +def apply_tags( + reference_id: str, + tags: list[str], + origin: str = "manual", + owner_id: str = "", +) -> AddTagsResult: + with create_session() as session: + ref_row = get_reference_with_owner_check(session, reference_id, owner_id) + + result = add_tags_to_reference( + session, + reference_id=reference_id, + tags=tags, + origin=origin, + create_if_missing=True, + reference_row=ref_row, + ) + session.commit() + + return result + + +def remove_tags( + reference_id: str, + tags: list[str], + owner_id: str = "", +) -> RemoveTagsResult: + with create_session() as session: + get_reference_with_owner_check(session, reference_id, owner_id) + + result = remove_tags_from_reference( + session, + reference_id=reference_id, + tags=tags, + ) + session.commit() + + return result + + +def list_tags( + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + order: str = "count_desc", + include_zero: bool = True, + owner_id: str = "", +) -> tuple[list[TagUsage], int]: + limit = max(1, min(1000, limit)) + offset = max(0, offset) + + with create_session() as session: + rows, total = list_tags_with_usage( + session, + prefix=prefix, + limit=limit, + offset=offset, + include_zero=include_zero, + order=order, + owner_id=owner_id, + ) + + return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total diff --git a/app/database/db.py b/app/database/db.py index 1de8b80ed..0aab09a49 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -3,6 +3,7 @@ import os import shutil from app.logger import log_startup_warning from utils.install_util import get_missing_requirements_message +from filelock import FileLock, Timeout from comfy.cli_args import args _DB_AVAILABLE = False @@ -14,8 +15,12 @@ try: from alembic.config import Config from alembic.runtime.migration import MigrationContext from alembic.script import ScriptDirectory - from sqlalchemy import create_engine + from sqlalchemy import create_engine, event from sqlalchemy.orm import sessionmaker + from sqlalchemy.pool import StaticPool + + from app.database.models import Base + import app.assets.database.models # noqa: F401 — register models with Base.metadata _DB_AVAILABLE = True except ImportError as e: @@ -65,9 +70,69 @@ def get_db_path(): raise ValueError(f"Unsupported database URL '{url}'.") +_db_lock = None + +def _acquire_file_lock(db_path): + """Acquire an OS-level file lock to prevent multi-process access. + + Uses filelock for cross-platform support (macOS, Linux, Windows). + The OS automatically releases the lock when the process exits, even on crashes. + """ + global _db_lock + lock_path = db_path + ".lock" + _db_lock = FileLock(lock_path) + try: + _db_lock.acquire(timeout=0) + except Timeout: + raise RuntimeError( + f"Could not acquire lock on database '{db_path}'. " + "Another ComfyUI process may already be using it. " + "Use --database-url to specify a separate database file." + ) + + +def _is_memory_db(db_url): + """Check if the database URL refers to an in-memory SQLite database.""" + return db_url in ("sqlite:///:memory:", "sqlite://") + + def init_db(): db_url = args.database_url logging.debug(f"Database URL: {db_url}") + + if _is_memory_db(db_url): + _init_memory_db(db_url) + else: + _init_file_db(db_url) + + +def _init_memory_db(db_url): + """Initialize an in-memory SQLite database using metadata.create_all. + + Alembic migrations don't work with in-memory SQLite because each + connection gets its own separate database — tables created by Alembic's + internal connection are lost immediately. + """ + engine = create_engine( + db_url, + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + Base.metadata.create_all(engine) + + global Session + Session = sessionmaker(bind=engine) + + +def _init_file_db(db_url): + """Initialize a file-backed SQLite database using Alembic migrations.""" db_path = get_db_path() db_exists = os.path.exists(db_path) @@ -75,6 +140,14 @@ def init_db(): # Check if we need to upgrade engine = create_engine(db_url) + + # Enable foreign key enforcement for SQLite + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + conn = engine.connect() context = MigrationContext.configure(conn) @@ -104,6 +177,12 @@ def init_db(): logging.exception("Error upgrading database: ") raise e + # Acquire an OS-level file lock after migrations are complete. + # Alembic uses its own connection, so we must wait until it's done + # before locking — otherwise our own lock blocks the migration. + conn.close() + _acquire_file_lock(db_path) + global Session Session = sessionmaker(bind=engine) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 13079c7bc..e9832acaf 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -232,7 +232,7 @@ database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") -parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") +parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index a90a5ca40..9f6918315 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -15,6 +15,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = { "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, "node_replacements": True, + "assets": args.enable_assets, } diff --git a/main.py b/main.py index 0f58d57b8..a8fc1a28d 100644 --- a/main.py +++ b/main.py @@ -7,14 +7,16 @@ import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger -from app.assets.scanner import seed_assets +from app.assets.seeder import asset_seeder import itertools import utils.extra_config +from utils.mime_types import init_mime_types import logging import sys from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags +from app.database.db import init_db, dependencies_available if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. @@ -161,6 +163,7 @@ def execute_prestartup_script(): logging.info("") apply_custom_paths() +init_mime_types() if args.enable_manager: comfyui_manager.prestartup() @@ -258,6 +261,7 @@ def prompt_worker(q, server_instance): for k in sensitive: extra_data[k] = sensitive[k] + asset_seeder.pause() e.execute(item[2], prompt_id, extra_data, item[4]) need_gc = True @@ -302,6 +306,7 @@ def prompt_worker(q, server_instance): last_gc_collect = current_time need_gc = False hook_breaker_ac10a0.restore_functions() + asset_seeder.resume() async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): @@ -352,12 +357,29 @@ def cleanup_temp(): def setup_database(): try: - from app.database.db import init_db, dependencies_available if dependencies_available(): init_db() - if not args.disable_assets_autoscan: - seed_assets(["models"], enable_logging=True) + if args.enable_assets: + if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True): + logging.info("Background asset scan initiated for models, input, output") except Exception as e: + if "database is locked" in str(e): + logging.error( + "Database is locked. Another ComfyUI process is already using this database.\n" + "To resolve this, specify a separate database file for this instance:\n" + " --database-url sqlite:///path/to/another.db" + ) + sys.exit(1) + if args.enable_assets: + logging.error( + f"Failed to initialize database: {e}\n" + "The --enable-assets flag requires a working database connection.\n" + "To resolve this, try one of the following:\n" + " 1. Install the latest requirements: pip install -r requirements.txt\n" + " 2. Specify an alternative database URL: --database-url sqlite:///path/to/your.db\n" + " 3. Use an in-memory database: --database-url sqlite:///:memory:" + ) + sys.exit(1) logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") @@ -440,5 +462,6 @@ if __name__ == "__main__": event_loop.run_until_complete(x) except KeyboardInterrupt: logging.info("\nStopped server") - - cleanup_temp() + finally: + asset_seeder.shutdown() + cleanup_temp() diff --git a/requirements.txt b/requirements.txt index dc9a9ded0..9527135ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,11 +20,13 @@ tqdm psutil alembic SQLAlchemy +filelock av>=14.2.0 comfy-kitchen>=0.2.7 comfy-aimdo>=0.2.7 requests simpleeval>=1.0.0 +blake3 #non essential dependencies: kornia>=0.7.1 diff --git a/server.py b/server.py index 275bce5a7..76904ebc9 100644 --- a/server.py +++ b/server.py @@ -33,8 +33,8 @@ import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal -from app.assets.scanner import seed_assets -from app.assets.api.routes import register_assets_system +from app.assets.seeder import asset_seeder +from app.assets.api.routes import register_assets_routes from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -197,10 +197,6 @@ class PromptServer(): def __init__(self, loop): PromptServer.instance = self - mimetypes.init() - mimetypes.add_type('application/javascript; charset=utf-8', '.js') - mimetypes.add_type('image/webp', '.webp') - self.user_manager = UserManager() self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() @@ -239,7 +235,11 @@ class PromptServer(): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") - register_assets_system(self.app, self.user_manager) + if args.enable_assets: + register_assets_routes(self.app, self.user_manager) + else: + register_assets_routes(self.app) + asset_seeder.disable() routes = web.RouteTableDef() self.routes = routes self.last_node_id = None @@ -697,10 +697,7 @@ class PromptServer(): @routes.get("/object_info") async def get_object_info(request): - try: - seed_assets(["models"]) - except Exception as e: - logging.error(f"Failed to seed assets: {e}") + asset_seeder.start(roots=("models", "input", "output")) with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: diff --git a/tests-unit/assets_test/conftest.py b/tests-unit/assets_test/conftest.py index 0a57dd7b5..6c5c56113 100644 --- a/tests-unit/assets_test/conftest.py +++ b/tests-unit/assets_test/conftest.py @@ -108,7 +108,7 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest) "main.py", f"--base-directory={str(comfy_tmp_base_dir)}", f"--database-url={db_url}", - "--disable-assets-autoscan", + "--enable-assets", "--listen", "127.0.0.1", "--port", @@ -212,7 +212,7 @@ def asset_factory(http: requests.Session, api_base: str): for aid in created: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}", timeout=30) + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) @pytest.fixture @@ -258,14 +258,4 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str): break for aid in ids: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}", timeout=30) - - -def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: - """Force a fast sync/seed pass by calling the seed endpoint.""" - session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30) - time.sleep(0.2) - - -def get_asset_filename(asset_hash: str, extension: str) -> str: - return asset_hash.removeprefix("blake3:") + extension + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) diff --git a/tests-unit/assets_test/helpers.py b/tests-unit/assets_test/helpers.py new file mode 100644 index 000000000..770e011f4 --- /dev/null +++ b/tests-unit/assets_test/helpers.py @@ -0,0 +1,28 @@ +"""Helper functions for assets integration tests.""" +import time + +import requests + + +def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: + """Force a synchronous sync/seed pass by calling the seed endpoint with wait=true. + + Retries on 409 (already running) until the previous scan finishes. + """ + deadline = time.monotonic() + 60 + while True: + r = session.post( + base_url + "/api/assets/seed?wait=true", + json={"roots": ["models", "input", "output"]}, + timeout=60, + ) + if r.status_code != 409: + assert r.status_code == 200, f"seed endpoint returned {r.status_code}: {r.text}" + return + if time.monotonic() > deadline: + raise TimeoutError("seed endpoint stuck in 409 (already running)") + time.sleep(0.25) + + +def get_asset_filename(asset_hash: str, extension: str) -> str: + return asset_hash.removeprefix("blake3:") + extension diff --git a/tests-unit/assets_test/queries/conftest.py b/tests-unit/assets_test/queries/conftest.py new file mode 100644 index 000000000..4ca0e86a9 --- /dev/null +++ b/tests-unit/assets_test/queries/conftest.py @@ -0,0 +1,20 @@ +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import Base + + +@pytest.fixture +def session(): + """In-memory SQLite session for fast unit tests.""" + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + with Session(engine) as sess: + yield sess + + +@pytest.fixture(autouse=True) +def autoclean_unit_test_assets(): + """Override parent autouse fixture - query tests don't need server cleanup.""" + yield diff --git a/tests-unit/assets_test/queries/test_asset.py b/tests-unit/assets_test/queries/test_asset.py new file mode 100644 index 000000000..08f84cd11 --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset.py @@ -0,0 +1,144 @@ +import uuid + +import pytest +from sqlalchemy.orm import Session + +from app.assets.helpers import get_utc_now +from app.assets.database.models import Asset +from app.assets.database.queries import ( + asset_exists_by_hash, + get_asset_by_hash, + upsert_asset, + bulk_insert_assets, +) + + +class TestAssetExistsByHash: + @pytest.mark.parametrize( + "setup_hash,query_hash,expected", + [ + (None, "nonexistent", False), # No asset exists + ("blake3:abc123", "blake3:abc123", True), # Asset exists with matching hash + (None, "", False), # Null hash in DB doesn't match empty string + ], + ids=["nonexistent", "existing", "null_hash_no_match"], + ) + def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected): + if setup_hash is not None or query_hash == "": + asset = Asset(hash=setup_hash, size_bytes=100) + session.add(asset) + session.commit() + + assert asset_exists_by_hash(session, asset_hash=query_hash) is expected + + +class TestGetAssetByHash: + @pytest.mark.parametrize( + "setup_hash,query_hash,should_find", + [ + (None, "nonexistent", False), + ("blake3:def456", "blake3:def456", True), + ], + ids=["nonexistent", "existing"], + ) + def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find): + if setup_hash is not None: + asset = Asset(hash=setup_hash, size_bytes=200, mime_type="image/png") + session.add(asset) + session.commit() + + result = get_asset_by_hash(session, asset_hash=query_hash) + if should_find: + assert result is not None + assert result.size_bytes == 200 + assert result.mime_type == "image/png" + else: + assert result is None + + +class TestUpsertAsset: + @pytest.mark.parametrize( + "first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime", + [ + # New asset creation + (None, None, 1024, "application/octet-stream", True, False, 1024, "application/octet-stream"), + # Existing asset, same values - no update + (500, "text/plain", 500, "text/plain", False, False, 500, "text/plain"), + # Existing asset with size 0, update with new values + (0, None, 2048, "image/png", False, True, 2048, "image/png"), + # Existing asset, second call with size 0 - no update + (1000, None, 0, None, False, False, 1000, None), + ], + ids=["new_asset", "existing_no_change", "update_from_zero", "zero_size_no_update"], + ) + def test_upsert_scenarios( + self, + session: Session, + first_size, + first_mime, + second_size, + second_mime, + expect_created, + expect_updated, + final_size, + final_mime, + ): + asset_hash = f"blake3:test_{first_size}_{second_size}" + + # First upsert (if first_size is not None, we're testing the second call) + if first_size is not None: + upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=first_size, + mime_type=first_mime, + ) + session.commit() + + # The upsert call we're testing + asset, created, updated = upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=second_size, + mime_type=second_mime, + ) + session.commit() + + assert created is expect_created + assert updated is expect_updated + assert asset.size_bytes == final_size + assert asset.mime_type == final_mime + + +class TestBulkInsertAssets: + def test_inserts_multiple_assets(self, session: Session): + now = get_utc_now() + rows = [ + {"id": str(uuid.uuid4()), "hash": "blake3:bulk1", "size_bytes": 100, "mime_type": "text/plain", "created_at": now}, + {"id": str(uuid.uuid4()), "hash": "blake3:bulk2", "size_bytes": 200, "mime_type": "image/png", "created_at": now}, + {"id": str(uuid.uuid4()), "hash": "blake3:bulk3", "size_bytes": 300, "mime_type": None, "created_at": now}, + ] + bulk_insert_assets(session, rows) + session.commit() + + assets = session.query(Asset).all() + assert len(assets) == 3 + hashes = {a.hash for a in assets} + assert hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"} + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_assets(session, []) + session.commit() + assert session.query(Asset).count() == 0 + + def test_handles_large_batch(self, session: Session): + """Test chunking logic with more rows than MAX_BIND_PARAMS allows.""" + now = get_utc_now() + rows = [ + {"id": str(uuid.uuid4()), "hash": f"blake3:large{i}", "size_bytes": i, "mime_type": None, "created_at": now} + for i in range(200) + ] + bulk_insert_assets(session, rows) + session.commit() + + assert session.query(Asset).count() == 200 diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py new file mode 100644 index 000000000..8f6c7fcdb --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -0,0 +1,517 @@ +import time +import uuid +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta +from app.assets.database.queries import ( + reference_exists_for_asset_id, + get_reference_by_id, + insert_reference, + get_or_create_reference, + update_reference_timestamps, + list_references_page, + fetch_reference_asset_and_tags, + fetch_reference_and_asset, + update_reference_access_time, + set_reference_metadata, + delete_reference_by_id, + set_reference_preview, + bulk_insert_references_ignore_conflicts, + get_reference_ids_by_ids, + ensure_tags_exist, + add_tags_to_reference, +) +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream") + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestReferenceExistsForAssetId: + def test_returns_false_when_no_reference(self, session: Session): + asset = _make_asset(session, "hash1") + assert reference_exists_for_asset_id(session, asset_id=asset.id) is False + + def test_returns_true_when_reference_exists(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset) + assert reference_exists_for_asset_id(session, asset_id=asset.id) is True + + +class TestGetReferenceById: + def test_returns_none_for_nonexistent(self, session: Session): + assert get_reference_by_id(session, reference_id="nonexistent") is None + + def test_returns_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="myfile.txt") + + result = get_reference_by_id(session, reference_id=ref.id) + assert result is not None + assert result.name == "myfile.txt" + + +class TestListReferencesPage: + def test_empty_db(self, session: Session): + refs, tag_map, total = list_references_page(session) + assert refs == [] + assert tag_map == {} + assert total == 0 + + def test_returns_references_with_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["alpha", "beta"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["alpha", "beta"]) + session.commit() + + refs, tag_map, total = list_references_page(session) + assert len(refs) == 1 + assert refs[0].id == ref.id + assert set(tag_map[ref.id]) == {"alpha", "beta"} + assert total == 1 + + def test_name_contains_filter(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="model_v1.safetensors") + _make_reference(session, asset, name="config.json") + session.commit() + + refs, _, total = list_references_page(session, name_contains="model") + assert total == 1 + assert refs[0].name == "model_v1.safetensors" + + def test_owner_visibility(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="public", owner_id="") + _make_reference(session, asset, name="private", owner_id="user1") + session.commit() + + # Empty owner sees only public + refs, _, total = list_references_page(session, owner_id="") + assert total == 1 + assert refs[0].name == "public" + + # Owner sees both + refs, _, total = list_references_page(session, owner_id="user1") + assert total == 2 + + def test_include_tags_filter(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="tagged") + _make_reference(session, asset, name="untagged") + ensure_tags_exist(session, ["wanted"]) + add_tags_to_reference(session, reference_id=ref1.id, tags=["wanted"]) + session.commit() + + refs, _, total = list_references_page(session, include_tags=["wanted"]) + assert total == 1 + assert refs[0].name == "tagged" + + def test_exclude_tags_filter(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="keep") + ref_exclude = _make_reference(session, asset, name="exclude") + ensure_tags_exist(session, ["bad"]) + add_tags_to_reference(session, reference_id=ref_exclude.id, tags=["bad"]) + session.commit() + + refs, _, total = list_references_page(session, exclude_tags=["bad"]) + assert total == 1 + assert refs[0].name == "keep" + + def test_sorting(self, session: Session): + asset = _make_asset(session, "hash1", size=100) + asset2 = _make_asset(session, "hash2", size=500) + _make_reference(session, asset, name="small") + _make_reference(session, asset2, name="large") + session.commit() + + refs, _, _ = list_references_page(session, sort="size", order="desc") + assert refs[0].name == "large" + + refs, _, _ = list_references_page(session, sort="name", order="asc") + assert refs[0].name == "large" + + +class TestFetchReferenceAssetAndTags: + def test_returns_none_for_nonexistent(self, session: Session): + result = fetch_reference_asset_and_tags(session, "nonexistent") + assert result is None + + def test_returns_tuple(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["tag1"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["tag1"]) + session.commit() + + result = fetch_reference_asset_and_tags(session, ref.id) + assert result is not None + ret_ref, ret_asset, ret_tags = result + assert ret_ref.id == ref.id + assert ret_asset.id == asset.id + assert ret_tags == ["tag1"] + + +class TestFetchReferenceAndAsset: + def test_returns_none_for_nonexistent(self, session: Session): + result = fetch_reference_and_asset(session, reference_id="nonexistent") + assert result is None + + def test_returns_tuple(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + result = fetch_reference_and_asset(session, reference_id=ref.id) + assert result is not None + ret_ref, ret_asset = result + assert ret_ref.id == ref.id + assert ret_asset.id == asset.id + + +class TestUpdateReferenceAccessTime: + def test_updates_last_access_time(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + original_time = ref.last_access_time + session.commit() + + import time + time.sleep(0.01) + + update_reference_access_time(session, reference_id=ref.id) + session.commit() + + session.refresh(ref) + assert ref.last_access_time > original_time + + +class TestDeleteReferenceById: + def test_deletes_existing(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + result = delete_reference_by_id(session, reference_id=ref.id, owner_id="") + assert result is True + assert get_reference_by_id(session, reference_id=ref.id) is None + + def test_returns_false_for_nonexistent(self, session: Session): + result = delete_reference_by_id(session, reference_id="nonexistent", owner_id="") + assert result is False + + def test_respects_owner_visibility(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + result = delete_reference_by_id(session, reference_id=ref.id, owner_id="user2") + assert result is False + assert get_reference_by_id(session, reference_id=ref.id) is not None + + +class TestSetReferencePreview: + def test_sets_preview(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + session.commit() + + set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id) + session.commit() + + session.refresh(ref) + assert ref.preview_id == preview_asset.id + + def test_clears_preview(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + ref.preview_id = preview_asset.id + session.commit() + + set_reference_preview(session, reference_id=ref.id, preview_asset_id=None) + session.commit() + + session.refresh(ref) + assert ref.preview_id is None + + def test_raises_for_nonexistent_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None) + + def test_raises_for_nonexistent_preview(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + with pytest.raises(ValueError, match="Preview Asset"): + set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent") + + +class TestInsertReference: + def test_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = insert_reference( + session, asset_id=asset.id, owner_id="user1", name="test.bin" + ) + session.commit() + + assert ref is not None + assert ref.name == "test.bin" + assert ref.owner_id == "user1" + + def test_allows_duplicate_names(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = insert_reference(session, asset_id=asset.id, owner_id="user1", name="dup.bin") + session.commit() + + # Duplicate names are now allowed + ref2 = insert_reference( + session, asset_id=asset.id, owner_id="user1", name="dup.bin" + ) + session.commit() + + assert ref1 is not None + assert ref2 is not None + assert ref1.id != ref2.id + + +class TestGetOrCreateReference: + def test_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref, created = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="new.bin" + ) + session.commit() + + assert created is True + assert ref.name == "new.bin" + + def test_always_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref1, created1 = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="existing.bin" + ) + session.commit() + + # Duplicate names are allowed, so always creates new + ref2, created2 = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="existing.bin" + ) + session.commit() + + assert created1 is True + assert created2 is True + assert ref1.id != ref2.id + + +class TestUpdateReferenceTimestamps: + def test_updates_timestamps(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + original_updated_at = ref.updated_at + session.commit() + + time.sleep(0.01) + update_reference_timestamps(session, ref) + session.commit() + + session.refresh(ref) + assert ref.updated_at > original_updated_at + + def test_updates_preview_id(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + session.commit() + + update_reference_timestamps(session, ref, preview_id=preview_asset.id) + session.commit() + + session.refresh(ref) + assert ref.preview_id == preview_asset.id + + +class TestSetReferenceMetadata: + def test_sets_metadata(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"key": "value"} + ) + session.commit() + + session.refresh(ref) + assert ref.user_metadata == {"key": "value"} + # Check metadata table + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "key" + assert meta[0].val_str == "value" + + def test_replaces_existing_metadata(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"old": "data"} + ) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"new": "data"} + ) + session.commit() + + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "new" + + def test_clears_metadata_with_empty_dict(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"key": "value"} + ) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={} + ) + session.commit() + + session.refresh(ref) + assert ref.user_metadata == {} + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 0 + + def test_raises_for_nonexistent(self, session: Session): + with pytest.raises(ValueError, match="not found"): + set_reference_metadata( + session, reference_id="nonexistent", user_metadata={"key": "value"} + ) + + +class TestBulkInsertReferencesIgnoreConflicts: + def test_inserts_multiple_references(self, session: Session): + asset = _make_asset(session, "hash1") + now = get_utc_now() + rows = [ + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "bulk1.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "bulk2.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + refs = session.query(AssetReference).all() + assert len(refs) == 2 + + def test_allows_duplicate_names(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="existing.bin", owner_id="") + session.commit() + + now = get_utc_now() + rows = [ + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "existing.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "new.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + # Duplicate names allowed, so all 3 rows exist + refs = session.query(AssetReference).all() + assert len(refs) == 3 + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_references_ignore_conflicts(session, []) + assert session.query(AssetReference).count() == 0 + + +class TestGetReferenceIdsByIds: + def test_returns_existing_ids(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="a.bin") + ref2 = _make_reference(session, asset, name="b.bin") + session.commit() + + found = get_reference_ids_by_ids(session, [ref1.id, ref2.id, "nonexistent"]) + + assert found == {ref1.id, ref2.id} + + def test_empty_list_returns_empty(self, session: Session): + found = get_reference_ids_by_ids(session, []) + assert found == set() diff --git a/tests-unit/assets_test/queries/test_cache_state.py b/tests-unit/assets_test/queries/test_cache_state.py new file mode 100644 index 000000000..ead60e570 --- /dev/null +++ b/tests-unit/assets_test/queries/test_cache_state.py @@ -0,0 +1,499 @@ +"""Tests for cache_state (AssetReference file path) query functions.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ( + list_references_by_asset_id, + upsert_reference, + get_unreferenced_unhashed_asset_ids, + delete_assets_by_ids, + get_references_for_prefixes, + bulk_update_needs_verify, + delete_references_by_ids, + delete_orphaned_seed_asset, + bulk_insert_references_ignore_conflicts, + get_references_by_paths_and_asset_ids, + mark_references_missing_outside_prefixes, + restore_references_by_paths, +) +from app.assets.helpers import select_best_live_path, get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + file_path: str, + name: str = "test", + mtime_ns: int | None = None, + needs_verify: bool = False, +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + asset_id=asset.id, + file_path=file_path, + name=name, + mtime_ns=mtime_ns, + needs_verify=needs_verify, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestListReferencesByAssetId: + def test_returns_empty_for_no_references(self, session: Session): + asset = _make_asset(session, "hash1") + refs = list_references_by_asset_id(session, asset_id=asset.id) + assert list(refs) == [] + + def test_returns_references_for_asset(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/path/a.bin", name="a") + _make_reference(session, asset, "/path/b.bin", name="b") + session.commit() + + refs = list_references_by_asset_id(session, asset_id=asset.id) + paths = [r.file_path for r in refs] + assert set(paths) == {"/path/a.bin", "/path/b.bin"} + + def test_does_not_return_other_assets_references(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + _make_reference(session, asset1, "/path/asset1.bin", name="a1") + _make_reference(session, asset2, "/path/asset2.bin", name="a2") + session.commit() + + refs = list_references_by_asset_id(session, asset_id=asset1.id) + paths = [r.file_path for r in refs] + assert paths == ["/path/asset1.bin"] + + +class TestSelectBestLivePath: + def test_returns_empty_for_empty_list(self): + result = select_best_live_path([]) + assert result == "" + + def test_returns_empty_when_no_files_exist(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, "/nonexistent/path.bin") + session.commit() + + result = select_best_live_path([ref]) + assert result == "" + + def test_prefers_verified_path(self, session: Session, tmp_path): + """needs_verify=False should be preferred.""" + asset = _make_asset(session, "hash1") + + verified_file = tmp_path / "verified.bin" + verified_file.write_bytes(b"data") + + unverified_file = tmp_path / "unverified.bin" + unverified_file.write_bytes(b"data") + + ref_verified = _make_reference( + session, asset, str(verified_file), name="verified", needs_verify=False + ) + ref_unverified = _make_reference( + session, asset, str(unverified_file), name="unverified", needs_verify=True + ) + session.commit() + + refs = [ref_unverified, ref_verified] + result = select_best_live_path(refs) + assert result == str(verified_file) + + def test_falls_back_to_existing_unverified(self, session: Session, tmp_path): + """If all references need verification, return first existing path.""" + asset = _make_asset(session, "hash1") + + existing_file = tmp_path / "exists.bin" + existing_file.write_bytes(b"data") + + ref = _make_reference(session, asset, str(existing_file), needs_verify=True) + session.commit() + + result = select_best_live_path([ref]) + assert result == str(existing_file) + + +class TestSelectBestLivePathWithMocking: + def test_handles_missing_file_path_attr(self): + """Gracefully handle references with None file_path.""" + + class MockRef: + file_path = None + needs_verify = False + + result = select_best_live_path([MockRef()]) + assert result == "" + + +class TestUpsertReference: + @pytest.mark.parametrize( + "initial_mtime,second_mtime,expect_created,expect_updated,final_mtime", + [ + # New reference creation + (None, 12345, True, False, 12345), + # Existing reference, same mtime - no update + (100, 100, False, False, 100), + # Existing reference, different mtime - update + (100, 200, False, True, 200), + ], + ids=["new_reference", "existing_no_change", "existing_update_mtime"], + ) + def test_upsert_scenarios( + self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime + ): + asset = _make_asset(session, "hash1") + file_path = f"/path_{initial_mtime}_{second_mtime}.bin" + name = f"file_{initial_mtime}_{second_mtime}" + + # Create initial reference if needed + if initial_mtime is not None: + upsert_reference(session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=initial_mtime) + session.commit() + + # The upsert call we're testing + created, updated = upsert_reference( + session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=second_mtime + ) + session.commit() + + assert created is expect_created + assert updated is expect_updated + ref = session.query(AssetReference).filter_by(file_path=file_path).one() + assert ref.mtime_ns == final_mtime + + def test_upsert_restores_missing_reference(self, session: Session): + """Upserting a reference that was marked missing should restore it.""" + asset = _make_asset(session, "hash1") + file_path = "/restored/file.bin" + + ref = _make_reference(session, asset, file_path, mtime_ns=100) + ref.is_missing = True + session.commit() + + created, updated = upsert_reference( + session, asset_id=asset.id, file_path=file_path, name="restored", mtime_ns=100 + ) + session.commit() + + assert created is False + assert updated is True + restored_ref = session.query(AssetReference).filter_by(file_path=file_path).one() + assert restored_ref.is_missing is False + + +class TestRestoreReferencesByPaths: + def test_restores_missing_references(self, session: Session): + asset = _make_asset(session, "hash1") + missing_path = "/missing/file.bin" + active_path = "/active/file.bin" + + missing_ref = _make_reference(session, asset, missing_path, name="missing") + missing_ref.is_missing = True + _make_reference(session, asset, active_path, name="active") + session.commit() + + restored = restore_references_by_paths(session, [missing_path]) + session.commit() + + assert restored == 1 + ref = session.query(AssetReference).filter_by(file_path=missing_path).one() + assert ref.is_missing is False + + def test_empty_list_restores_nothing(self, session: Session): + restored = restore_references_by_paths(session, []) + assert restored == 0 + + +class TestMarkReferencesMissingOutsidePrefixes: + def test_marks_references_missing_outside_prefixes(self, session: Session, tmp_path): + asset = _make_asset(session, "hash1") + valid_dir = tmp_path / "valid" + valid_dir.mkdir() + invalid_dir = tmp_path / "invalid" + invalid_dir.mkdir() + + valid_path = str(valid_dir / "file.bin") + invalid_path = str(invalid_dir / "file.bin") + + _make_reference(session, asset, valid_path, name="valid") + _make_reference(session, asset, invalid_path, name="invalid") + session.commit() + + marked = mark_references_missing_outside_prefixes(session, [str(valid_dir)]) + session.commit() + + assert marked == 1 + all_refs = session.query(AssetReference).all() + assert len(all_refs) == 2 + + valid_ref = next(r for r in all_refs if r.file_path == valid_path) + invalid_ref = next(r for r in all_refs if r.file_path == invalid_path) + assert valid_ref.is_missing is False + assert invalid_ref.is_missing is True + + def test_empty_prefixes_marks_nothing(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/some/path.bin") + session.commit() + + marked = mark_references_missing_outside_prefixes(session, []) + + assert marked == 0 + + +class TestGetUnreferencedUnhashedAssetIds: + def test_returns_unreferenced_unhashed_assets(self, session: Session): + # Unhashed asset (hash=None) with no references (no file_path) + no_refs = _make_asset(session, hash_val=None) + # Unhashed asset with active reference (not unreferenced) + with_active_ref = _make_asset(session, hash_val=None) + _make_reference(session, with_active_ref, "/has/ref.bin", name="has_ref") + # Unhashed asset with only missing reference (should be unreferenced) + with_missing_ref = _make_asset(session, hash_val=None) + missing_ref = _make_reference(session, with_missing_ref, "/missing/ref.bin", name="missing_ref") + missing_ref.is_missing = True + # Regular asset (hash not None) - should not be returned + _make_asset(session, hash_val="blake3:regular") + session.commit() + + unreferenced = get_unreferenced_unhashed_asset_ids(session) + + assert no_refs.id in unreferenced + assert with_missing_ref.id in unreferenced + assert with_active_ref.id not in unreferenced + + +class TestDeleteAssetsByIds: + def test_deletes_assets_and_references(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/test/path.bin", name="test") + session.commit() + + deleted = delete_assets_by_ids(session, [asset.id]) + session.commit() + + assert deleted == 1 + assert session.query(Asset).count() == 0 + assert session.query(AssetReference).count() == 0 + + def test_empty_list_deletes_nothing(self, session: Session): + _make_asset(session, "hash1") + session.commit() + + deleted = delete_assets_by_ids(session, []) + + assert deleted == 0 + assert session.query(Asset).count() == 1 + + +class TestGetReferencesForPrefixes: + def test_returns_references_matching_prefix(self, session: Session, tmp_path): + asset = _make_asset(session, "hash1") + dir1 = tmp_path / "dir1" + dir1.mkdir() + dir2 = tmp_path / "dir2" + dir2.mkdir() + + path1 = str(dir1 / "file.bin") + path2 = str(dir2 / "file.bin") + + _make_reference(session, asset, path1, name="file1", mtime_ns=100) + _make_reference(session, asset, path2, name="file2", mtime_ns=200) + session.commit() + + rows = get_references_for_prefixes(session, [str(dir1)]) + + assert len(rows) == 1 + assert rows[0].file_path == path1 + + def test_empty_prefixes_returns_empty(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/some/path.bin") + session.commit() + + rows = get_references_for_prefixes(session, []) + + assert rows == [] + + +class TestBulkSetNeedsVerify: + def test_sets_needs_verify_flag(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, "/path1.bin", needs_verify=False) + ref2 = _make_reference(session, asset, "/path2.bin", needs_verify=False) + session.commit() + + updated = bulk_update_needs_verify(session, [ref1.id, ref2.id], True) + session.commit() + + assert updated == 2 + session.refresh(ref1) + session.refresh(ref2) + assert ref1.needs_verify is True + assert ref2.needs_verify is True + + def test_empty_list_updates_nothing(self, session: Session): + updated = bulk_update_needs_verify(session, [], True) + assert updated == 0 + + +class TestDeleteReferencesByIds: + def test_deletes_references_by_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, "/path1.bin") + _make_reference(session, asset, "/path2.bin") + session.commit() + + deleted = delete_references_by_ids(session, [ref1.id]) + session.commit() + + assert deleted == 1 + assert session.query(AssetReference).count() == 1 + + def test_empty_list_deletes_nothing(self, session: Session): + deleted = delete_references_by_ids(session, []) + assert deleted == 0 + + +class TestDeleteOrphanedSeedAsset: + @pytest.mark.parametrize( + "create_asset,expected_deleted,expected_count", + [ + (True, True, 0), # Existing asset gets deleted + (False, False, 0), # Nonexistent returns False + ], + ids=["deletes_existing", "nonexistent_returns_false"], + ) + def test_delete_orphaned_seed_asset( + self, session: Session, create_asset, expected_deleted, expected_count + ): + asset_id = "nonexistent-id" + if create_asset: + asset = _make_asset(session, hash_val=None) + asset_id = asset.id + _make_reference(session, asset, "/test/path.bin", name="test") + session.commit() + + deleted = delete_orphaned_seed_asset(session, asset_id) + if create_asset: + session.commit() + + assert deleted is expected_deleted + assert session.query(Asset).count() == expected_count + + +class TestBulkInsertReferencesIgnoreConflicts: + def test_inserts_multiple_references(self, session: Session): + asset = _make_asset(session, "hash1") + now = get_utc_now() + rows = [ + { + "asset_id": asset.id, + "file_path": "/bulk1.bin", + "name": "bulk1", + "mtime_ns": 100, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "asset_id": asset.id, + "file_path": "/bulk2.bin", + "name": "bulk2", + "mtime_ns": 200, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + assert session.query(AssetReference).count() == 2 + + def test_ignores_conflicts(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/existing.bin", mtime_ns=100) + session.commit() + + now = get_utc_now() + rows = [ + { + "asset_id": asset.id, + "file_path": "/existing.bin", + "name": "existing", + "mtime_ns": 999, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "asset_id": asset.id, + "file_path": "/new.bin", + "name": "new", + "mtime_ns": 200, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + assert session.query(AssetReference).count() == 2 + existing = session.query(AssetReference).filter_by(file_path="/existing.bin").one() + assert existing.mtime_ns == 100 # Original value preserved + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_references_ignore_conflicts(session, []) + assert session.query(AssetReference).count() == 0 + + +class TestGetReferencesByPathsAndAssetIds: + def test_returns_matching_paths(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + + _make_reference(session, asset1, "/path1.bin") + _make_reference(session, asset2, "/path2.bin") + session.commit() + + path_to_asset = { + "/path1.bin": asset1.id, + "/path2.bin": asset2.id, + } + winners = get_references_by_paths_and_asset_ids(session, path_to_asset) + + assert winners == {"/path1.bin", "/path2.bin"} + + def test_excludes_non_matching_asset_ids(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + + _make_reference(session, asset1, "/path1.bin") + session.commit() + + # Path exists but with different asset_id + path_to_asset = {"/path1.bin": asset2.id} + winners = get_references_by_paths_and_asset_ids(session, path_to_asset) + + assert winners == set() + + def test_empty_dict_returns_empty(self, session: Session): + winners = get_references_by_paths_and_asset_ids(session, {}) + assert winners == set() diff --git a/tests-unit/assets_test/queries/test_metadata.py b/tests-unit/assets_test/queries/test_metadata.py new file mode 100644 index 000000000..6a545e819 --- /dev/null +++ b/tests-unit/assets_test/queries/test_metadata.py @@ -0,0 +1,184 @@ +"""Tests for metadata filtering logic in asset_reference queries.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta +from app.assets.database.queries import list_references_page +from app.assets.database.queries.asset_reference import convert_metadata_to_rows +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str) -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str, + metadata: dict | None = None, +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id="", + name=name, + asset_id=asset.id, + user_metadata=metadata, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + + if metadata: + for key, val in metadata.items(): + for row in convert_metadata_to_rows(key, val): + meta_row = AssetReferenceMeta( + asset_reference_id=ref.id, + key=row["key"], + ordinal=row.get("ordinal", 0), + val_str=row.get("val_str"), + val_num=row.get("val_num"), + val_bool=row.get("val_bool"), + val_json=row.get("val_json"), + ) + session.add(meta_row) + session.flush() + + return ref + + +class TestMetadataFilterByType: + """Table-driven tests for metadata filtering by different value types.""" + + @pytest.mark.parametrize( + "match_meta,nomatch_meta,filter_key,filter_val", + [ + # String matching + ({"category": "models"}, {"category": "images"}, "category", "models"), + # Integer matching + ({"epoch": 5}, {"epoch": 10}, "epoch", 5), + # Float matching + ({"score": 0.95}, {"score": 0.5}, "score", 0.95), + # Boolean True matching + ({"enabled": True}, {"enabled": False}, "enabled", True), + # Boolean False matching + ({"enabled": False}, {"enabled": True}, "enabled", False), + ], + ids=["string", "int", "float", "bool_true", "bool_false"], + ) + def test_filter_matches_correct_value( + self, session: Session, match_meta, nomatch_meta, filter_key, filter_val + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "match", match_meta) + _make_reference(session, asset, "nomatch", nomatch_meta) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={filter_key: filter_val} + ) + assert total == 1 + assert refs[0].name == "match" + + @pytest.mark.parametrize( + "stored_meta,filter_key,filter_val", + [ + # String no match + ({"category": "models"}, "category", "other"), + # Int no match + ({"epoch": 5}, "epoch", 99), + # Float no match + ({"score": 0.5}, "score", 0.99), + ], + ids=["string_no_match", "int_no_match", "float_no_match"], + ) + def test_filter_returns_empty_when_no_match( + self, session: Session, stored_meta, filter_key, filter_val + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "item", stored_meta) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={filter_key: filter_val} + ) + assert total == 0 + + +class TestMetadataFilterNull: + """Tests for null/missing key filtering.""" + + @pytest.mark.parametrize( + "match_name,match_meta,nomatch_name,nomatch_meta,filter_key", + [ + # Null matches missing key + ("missing_key", {}, "has_key", {"optional": "value"}, "optional"), + # Null matches explicit null + ("explicit_null", {"nullable": None}, "has_value", {"nullable": "present"}, "nullable"), + ], + ids=["missing_key", "explicit_null"], + ) + def test_null_filter_matches( + self, session: Session, match_name, match_meta, nomatch_name, nomatch_meta, filter_key + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, match_name, match_meta) + _make_reference(session, asset, nomatch_name, nomatch_meta) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={filter_key: None}) + assert total == 1 + assert refs[0].name == match_name + + +class TestMetadataFilterList: + """Tests for list-based (OR) filtering.""" + + def test_filter_by_list_matches_any(self, session: Session): + """List values should match ANY of the values (OR).""" + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "cat_a", {"category": "a"}) + _make_reference(session, asset, "cat_b", {"category": "b"}) + _make_reference(session, asset, "cat_c", {"category": "c"}) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={"category": ["a", "b"]}) + assert total == 2 + names = {r.name for r in refs} + assert names == {"cat_a", "cat_b"} + + +class TestMetadataFilterMultipleKeys: + """Tests for multiple filter keys (AND semantics).""" + + def test_multiple_keys_must_all_match(self, session: Session): + """Multiple keys should ALL match (AND).""" + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "match", {"type": "model", "version": 2}) + _make_reference(session, asset, "wrong_type", {"type": "config", "version": 2}) + _make_reference(session, asset, "wrong_version", {"type": "model", "version": 1}) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={"type": "model", "version": 2} + ) + assert total == 1 + assert refs[0].name == "match" + + +class TestMetadataFilterEmptyDict: + """Tests for empty filter behavior.""" + + def test_empty_filter_returns_all(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "a", {"key": "val"}) + _make_reference(session, asset, "b", {}) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={}) + assert total == 2 diff --git a/tests-unit/assets_test/queries/test_tags.py b/tests-unit/assets_test/queries/test_tags.py new file mode 100644 index 000000000..4ed99aa37 --- /dev/null +++ b/tests-unit/assets_test/queries/test_tags.py @@ -0,0 +1,366 @@ +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, AssetReferenceMeta, Tag +from app.assets.database.queries import ( + ensure_tags_exist, + get_reference_tags, + set_reference_tags, + add_tags_to_reference, + remove_tags_from_reference, + add_missing_tag_for_asset_id, + remove_missing_tag_for_asset_id, + list_tags_with_usage, + bulk_insert_tags_and_meta, +) +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None) -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestEnsureTagsExist: + def test_creates_new_tags(self, session: Session): + ensure_tags_exist(session, ["alpha", "beta"], tag_type="user") + session.commit() + + tags = session.query(Tag).all() + assert {t.name for t in tags} == {"alpha", "beta"} + + def test_is_idempotent(self, session: Session): + ensure_tags_exist(session, ["alpha"], tag_type="user") + ensure_tags_exist(session, ["alpha"], tag_type="user") + session.commit() + + assert session.query(Tag).count() == 1 + + def test_normalizes_tags(self, session: Session): + ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"]) + session.commit() + + tags = session.query(Tag).all() + assert {t.name for t in tags} == {"alpha", "beta"} + + def test_empty_list_is_noop(self, session: Session): + ensure_tags_exist(session, []) + session.commit() + assert session.query(Tag).count() == 0 + + def test_tag_type_is_set(self, session: Session): + ensure_tags_exist(session, ["system-tag"], tag_type="system") + session.commit() + + tag = session.query(Tag).filter_by(name="system-tag").one() + assert tag.tag_type == "system" + + +class TestGetReferenceTags: + def test_returns_empty_for_no_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + tags = get_reference_tags(session, reference_id=ref.id) + assert tags == [] + + def test_returns_tags_for_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + ensure_tags_exist(session, ["tag1", "tag2"]) + session.add_all([ + AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag1", origin="manual", added_at=get_utc_now()), + AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag2", origin="manual", added_at=get_utc_now()), + ]) + session.flush() + + tags = get_reference_tags(session, reference_id=ref.id) + assert set(tags) == {"tag1", "tag2"} + + +class TestSetReferenceTags: + def test_adds_new_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + result = set_reference_tags(session, reference_id=ref.id, tags=["a", "b"]) + session.commit() + + assert set(result.added) == {"a", "b"} + assert result.removed == [] + assert set(result.total) == {"a", "b"} + + def test_removes_old_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + set_reference_tags(session, reference_id=ref.id, tags=["a", "b", "c"]) + result = set_reference_tags(session, reference_id=ref.id, tags=["a"]) + session.commit() + + assert result.added == [] + assert set(result.removed) == {"b", "c"} + assert result.total == ["a"] + + def test_replaces_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + set_reference_tags(session, reference_id=ref.id, tags=["a", "b"]) + result = set_reference_tags(session, reference_id=ref.id, tags=["b", "c"]) + session.commit() + + assert result.added == ["c"] + assert result.removed == ["a"] + assert set(result.total) == {"b", "c"} + + +class TestAddTagsToReference: + def test_adds_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"]) + session.commit() + + assert set(result.added) == {"x", "y"} + assert result.already_present == [] + + def test_reports_already_present(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["x"]) + result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"]) + session.commit() + + assert result.added == ["y"] + assert result.already_present == ["x"] + + def test_raises_for_missing_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + add_tags_to_reference(session, reference_id="nonexistent", tags=["x"]) + + +class TestRemoveTagsFromReference: + def test_removes_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["a", "b", "c"]) + result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "b"]) + session.commit() + + assert set(result.removed) == {"a", "b"} + assert result.not_present == [] + assert result.total_tags == ["c"] + + def test_reports_not_present(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["a"]) + result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "x"]) + session.commit() + + assert result.removed == ["a"] + assert result.not_present == ["x"] + + def test_raises_for_missing_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + remove_tags_from_reference(session, reference_id="nonexistent", tags=["x"]) + + +class TestMissingTagFunctions: + def test_add_missing_tag_for_asset_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + + add_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert "missing" in tags + + def test_add_missing_tag_is_idempotent(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + + add_missing_tag_for_asset_id(session, asset_id=asset.id) + add_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="missing").all() + assert len(links) == 1 + + def test_remove_missing_tag_for_asset_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + add_missing_tag_for_asset_id(session, asset_id=asset.id) + + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert "missing" not in tags + + +class TestListTagsWithUsage: + def test_returns_tags_with_counts(self, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags_with_usage(session) + + tag_dict = {name: count for name, _, count in rows} + assert tag_dict["used"] == 1 + assert tag_dict["unused"] == 0 + assert total == 2 + + def test_exclude_zero_counts(self, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags_with_usage(session, include_zero=False) + + tag_names = {name for name, _, _ in rows} + assert "used" in tag_names + assert "unused" not in tag_names + + def test_prefix_filter(self, session: Session): + ensure_tags_exist(session, ["alpha", "beta", "alphabet"]) + session.commit() + + rows, total = list_tags_with_usage(session, prefix="alph") + + tag_names = {name for name, _, _ in rows} + assert tag_names == {"alpha", "alphabet"} + + def test_order_by_name(self, session: Session): + ensure_tags_exist(session, ["zebra", "alpha", "middle"]) + session.commit() + + rows, _ = list_tags_with_usage(session, order="name_asc") + + names = [name for name, _, _ in rows] + assert names == ["alpha", "middle", "zebra"] + + def test_owner_visibility(self, session: Session): + ensure_tags_exist(session, ["shared-tag", "owner-tag"]) + + asset = _make_asset(session, "hash1") + shared_ref = _make_reference(session, asset, name="shared", owner_id="") + owner_ref = _make_reference(session, asset, name="owned", owner_id="user1") + + add_tags_to_reference(session, reference_id=shared_ref.id, tags=["shared-tag"]) + add_tags_to_reference(session, reference_id=owner_ref.id, tags=["owner-tag"]) + session.commit() + + # Empty owner sees only shared + rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False) + tag_dict = {name: count for name, _, count in rows} + assert tag_dict.get("shared-tag", 0) == 1 + assert tag_dict.get("owner-tag", 0) == 0 + + # User1 sees both + rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False) + tag_dict = {name: count for name, _, count in rows} + assert tag_dict.get("shared-tag", 0) == 1 + assert tag_dict.get("owner-tag", 0) == 1 + + +class TestBulkInsertTagsAndMeta: + def test_inserts_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["bulk-tag1", "bulk-tag2"]) + session.commit() + + now = get_utc_now() + tag_rows = [ + {"asset_reference_id": ref.id, "tag_name": "bulk-tag1", "origin": "manual", "added_at": now}, + {"asset_reference_id": ref.id, "tag_name": "bulk-tag2", "origin": "manual", "added_at": now}, + ] + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[]) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert set(tags) == {"bulk-tag1", "bulk-tag2"} + + def test_inserts_meta(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + meta_rows = [ + { + "asset_reference_id": ref.id, + "key": "meta-key", + "ordinal": 0, + "val_str": "meta-value", + "val_num": None, + "val_bool": None, + "val_json": None, + }, + ] + bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=meta_rows) + session.commit() + + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "meta-key" + assert meta[0].val_str == "meta-value" + + def test_ignores_conflicts(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["existing-tag"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["existing-tag"]) + session.commit() + + now = get_utc_now() + tag_rows = [ + {"asset_reference_id": ref.id, "tag_name": "existing-tag", "origin": "duplicate", "added_at": now}, + ] + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[]) + session.commit() + + # Should still have only one tag link + links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="existing-tag").all() + assert len(links) == 1 + # Origin should be original, not overwritten + assert links[0].origin == "manual" + + def test_empty_lists_is_noop(self, session: Session): + bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[]) + assert session.query(AssetReferenceTag).count() == 0 + assert session.query(AssetReferenceMeta).count() == 0 diff --git a/tests-unit/assets_test/services/__init__.py b/tests-unit/assets_test/services/__init__.py new file mode 100644 index 000000000..d0213422e --- /dev/null +++ b/tests-unit/assets_test/services/__init__.py @@ -0,0 +1 @@ +# Service layer tests diff --git a/tests-unit/assets_test/services/conftest.py b/tests-unit/assets_test/services/conftest.py new file mode 100644 index 000000000..31c763d48 --- /dev/null +++ b/tests-unit/assets_test/services/conftest.py @@ -0,0 +1,54 @@ +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import Base + + +@pytest.fixture(autouse=True) +def autoclean_unit_test_assets(): + """Override parent autouse fixture - service unit tests don't need server cleanup.""" + yield + + +@pytest.fixture +def db_engine(): + """In-memory SQLite engine for fast unit tests.""" + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def session(db_engine): + """Session fixture for tests that need direct DB access.""" + with Session(db_engine) as sess: + yield sess + + +@pytest.fixture +def mock_create_session(db_engine): + """Patch create_session to use our in-memory database.""" + from contextlib import contextmanager + from sqlalchemy.orm import Session as SASession + + @contextmanager + def _create_session(): + with SASession(db_engine) as sess: + yield sess + + with patch("app.assets.services.ingest.create_session", _create_session), \ + patch("app.assets.services.asset_management.create_session", _create_session), \ + patch("app.assets.services.tagging.create_session", _create_session): + yield _create_session + + +@pytest.fixture +def temp_dir(): + """Temporary directory for file operations.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) diff --git a/tests-unit/assets_test/services/test_asset_management.py b/tests-unit/assets_test/services/test_asset_management.py new file mode 100644 index 000000000..101ef7292 --- /dev/null +++ b/tests-unit/assets_test/services/test_asset_management.py @@ -0,0 +1,268 @@ +"""Tests for asset_management services.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference +from app.assets.helpers import get_utc_now +from app.assets.services import ( + get_asset_detail, + update_asset_metadata, + delete_asset_reference, + set_asset_preview, +) + + +def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream") + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestGetAssetDetail: + def test_returns_none_for_nonexistent(self, mock_create_session): + result = get_asset_detail(reference_id="nonexistent") + assert result is None + + def test_returns_asset_with_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["alpha", "beta"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["alpha", "beta"]) + session.commit() + + result = get_asset_detail(reference_id=ref.id) + + assert result is not None + assert result.ref.id == ref.id + assert result.asset.hash == asset.hash + assert set(result.tags) == {"alpha", "beta"} + + def test_respects_owner_visibility(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + # Wrong owner cannot see + result = get_asset_detail(reference_id=ref.id, owner_id="user2") + assert result is None + + # Correct owner can see + result = get_asset_detail(reference_id=ref.id, owner_id="user1") + assert result is not None + + +class TestUpdateAssetMetadata: + def test_updates_name(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, name="old_name.bin") + ref_id = ref.id + session.commit() + + update_asset_metadata( + reference_id=ref_id, + name="new_name.bin", + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.name == "new_name.bin" + + def test_updates_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["old"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["old"]) + session.commit() + + result = update_asset_metadata( + reference_id=ref.id, + tags=["new1", "new2"], + ) + + assert set(result.tags) == {"new1", "new2"} + assert "old" not in result.tags + + def test_updates_user_metadata(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ref_id = ref.id + session.commit() + + update_asset_metadata( + reference_id=ref_id, + user_metadata={"key": "value", "num": 42}, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.user_metadata["key"] == "value" + assert updated_ref.user_metadata["num"] == 42 + + def test_raises_for_nonexistent(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + update_asset_metadata(reference_id="nonexistent", name="fail") + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + update_asset_metadata( + reference_id=ref.id, + name="new", + owner_id="user2", + ) + + +class TestDeleteAssetReference: + def test_soft_deletes_reference(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ref_id = ref.id + session.commit() + + result = delete_asset_reference( + reference_id=ref_id, + owner_id="", + delete_content_if_orphan=False, + ) + + assert result is True + # Row still exists but is marked as soft-deleted + session.expire_all() + row = session.get(AssetReference, ref_id) + assert row is not None + assert row.deleted_at is not None + + def test_returns_false_for_nonexistent(self, mock_create_session): + result = delete_asset_reference( + reference_id="nonexistent", + owner_id="", + ) + assert result is False + + def test_returns_false_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + ref_id = ref.id + session.commit() + + result = delete_asset_reference( + reference_id=ref_id, + owner_id="user2", + ) + + assert result is False + assert session.get(AssetReference, ref_id) is not None + + def test_keeps_asset_if_other_references_exist(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref1 = _make_reference(session, asset, name="ref1") + _make_reference(session, asset, name="ref2") # Second ref keeps asset alive + asset_id = asset.id + session.commit() + + delete_asset_reference( + reference_id=ref1.id, + owner_id="", + delete_content_if_orphan=True, + ) + + # Asset should still exist + assert session.get(Asset, asset_id) is not None + + def test_deletes_orphaned_asset(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + asset_id = asset.id + ref_id = ref.id + session.commit() + + delete_asset_reference( + reference_id=ref_id, + owner_id="", + delete_content_if_orphan=True, + ) + + # Both ref and asset should be gone + assert session.get(AssetReference, ref_id) is None + assert session.get(Asset, asset_id) is None + + +class TestSetAssetPreview: + def test_sets_preview(self, mock_create_session, session: Session): + asset = _make_asset(session, hash_val="blake3:main") + preview_asset = _make_asset(session, hash_val="blake3:preview") + ref = _make_reference(session, asset) + ref_id = ref.id + preview_id = preview_asset.id + session.commit() + + set_asset_preview( + reference_id=ref_id, + preview_asset_id=preview_id, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.preview_id == preview_id + + def test_clears_preview(self, mock_create_session, session: Session): + asset = _make_asset(session) + preview_asset = _make_asset(session, hash_val="blake3:preview") + ref = _make_reference(session, asset) + ref.preview_id = preview_asset.id + ref_id = ref.id + session.commit() + + set_asset_preview( + reference_id=ref_id, + preview_asset_id=None, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.preview_id is None + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + set_asset_preview(reference_id="nonexistent") + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + set_asset_preview( + reference_id=ref.id, + preview_asset_id=None, + owner_id="user2", + ) diff --git a/tests-unit/assets_test/services/test_bulk_ingest.py b/tests-unit/assets_test/services/test_bulk_ingest.py new file mode 100644 index 000000000..26e22a01d --- /dev/null +++ b/tests-unit/assets_test/services/test_bulk_ingest.py @@ -0,0 +1,137 @@ +"""Tests for bulk ingest services.""" + +from pathlib import Path + +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets + + +class TestBatchInsertSeedAssets: + def test_populates_mime_type_for_model_files(self, session: Session, temp_dir: Path): + """Verify mime_type is stored in the Asset table for model files.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"fake safetensors content") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 24, + "mtime_ns": 1234567890000000000, + "info_name": "Test Model", + "tags": ["models"], + "fname": "model.safetensors", + "metadata": None, + "hash": None, + "mime_type": "application/safetensors", + } + ] + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == 1 + + # Verify Asset has mime_type populated + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].mime_type == "application/safetensors" + + def test_mime_type_none_when_not_provided(self, session: Session, temp_dir: Path): + """Verify mime_type is None when not provided in spec.""" + file_path = temp_dir / "unknown.bin" + file_path.write_bytes(b"binary data") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 11, + "mtime_ns": 1234567890000000000, + "info_name": "Unknown File", + "tags": [], + "fname": "unknown.bin", + "metadata": None, + "hash": None, + "mime_type": None, + } + ] + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == 1 + + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].mime_type is None + + def test_various_model_mime_types(self, session: Session, temp_dir: Path): + """Verify various model file types get correct mime_type.""" + test_cases = [ + ("model.safetensors", "application/safetensors"), + ("model.pt", "application/pytorch"), + ("model.ckpt", "application/pickle"), + ("model.gguf", "application/gguf"), + ] + + specs: list[SeedAssetSpec] = [] + for filename, mime_type in test_cases: + file_path = temp_dir / filename + file_path.write_bytes(b"content") + specs.append( + { + "abs_path": str(file_path), + "size_bytes": 7, + "mtime_ns": 1234567890000000000, + "info_name": filename, + "tags": [], + "fname": filename, + "metadata": None, + "hash": None, + "mime_type": mime_type, + } + ) + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == len(test_cases) + + for filename, expected_mime in test_cases: + ref = session.query(AssetReference).filter_by(name=filename).first() + assert ref is not None + asset = session.query(Asset).filter_by(id=ref.asset_id).first() + assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}" + + +class TestMetadataExtraction: + def test_extracts_mime_type_for_model_files(self, temp_dir: Path): + """Verify metadata extraction returns correct mime_type for model files.""" + from app.assets.services.metadata_extract import extract_file_metadata + + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"fake safetensors content") + + meta = extract_file_metadata(str(file_path)) + + assert meta.content_type == "application/safetensors" + + def test_mime_type_for_various_model_formats(self, temp_dir: Path): + """Verify various model file types get correct mime_type from metadata.""" + from app.assets.services.metadata_extract import extract_file_metadata + + test_cases = [ + ("model.safetensors", "application/safetensors"), + ("model.sft", "application/safetensors"), + ("model.pt", "application/pytorch"), + ("model.pth", "application/pytorch"), + ("model.ckpt", "application/pickle"), + ("model.pkl", "application/pickle"), + ("model.gguf", "application/gguf"), + ] + + for filename, expected_mime in test_cases: + file_path = temp_dir / filename + file_path.write_bytes(b"content") + + meta = extract_file_metadata(str(file_path)) + + assert meta.content_type == expected_mime, f"Expected {expected_mime} for {filename}, got {meta.content_type}" diff --git a/tests-unit/assets_test/services/test_enrich.py b/tests-unit/assets_test/services/test_enrich.py new file mode 100644 index 000000000..2bd79a01a --- /dev/null +++ b/tests-unit/assets_test/services/test_enrich.py @@ -0,0 +1,207 @@ +"""Tests for asset enrichment (mime_type and hash population).""" +from pathlib import Path + +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.scanner import ( + ENRICHMENT_HASHED, + ENRICHMENT_METADATA, + ENRICHMENT_STUB, + enrich_asset, +) + + +def _create_stub_asset( + session: Session, + file_path: str, + asset_id: str = "test-asset-id", + reference_id: str = "test-ref-id", + name: str | None = None, +) -> tuple[Asset, AssetReference]: + """Create a stub asset with reference for testing enrichment.""" + asset = Asset( + id=asset_id, + hash=None, + size_bytes=100, + mime_type=None, + ) + session.add(asset) + session.flush() + + ref = AssetReference( + id=reference_id, + asset_id=asset_id, + name=name or f"test-asset-{asset_id}", + owner_id="system", + file_path=file_path, + mtime_ns=1234567890000000000, + enrichment_level=ENRICHMENT_STUB, + ) + session.add(ref) + session.flush() + + return asset, ref + + +class TestEnrichAsset: + def test_extracts_mime_type_and_updates_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify mime_type is written to the Asset table during enrichment.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"\x00" * 100) + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-1", "ref-1" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=False, + ) + + assert new_level == ENRICHMENT_METADATA + + session.expire_all() + updated_asset = session.get(Asset, "asset-1") + assert updated_asset is not None + assert updated_asset.mime_type == "application/safetensors" + + def test_computes_hash_and_updates_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify hash is written to the Asset table during enrichment.""" + file_path = temp_dir / "data.bin" + file_path.write_bytes(b"test content for hashing") + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-2", "ref-2" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + assert new_level == ENRICHMENT_HASHED + + session.expire_all() + updated_asset = session.get(Asset, "asset-2") + assert updated_asset is not None + assert updated_asset.hash is not None + assert updated_asset.hash.startswith("blake3:") + + def test_enrichment_updates_both_mime_and_hash( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify both mime_type and hash are set when full enrichment runs.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"\x00" * 50) + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-3", "ref-3" + ) + session.commit() + + enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + session.expire_all() + updated_asset = session.get(Asset, "asset-3") + assert updated_asset is not None + assert updated_asset.mime_type == "application/safetensors" + assert updated_asset.hash is not None + assert updated_asset.hash.startswith("blake3:") + + def test_missing_file_returns_stub_level( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify missing files don't cause errors and return STUB level.""" + file_path = temp_dir / "nonexistent.bin" + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-4", "ref-4" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + assert new_level == ENRICHMENT_STUB + + session.expire_all() + updated_asset = session.get(Asset, "asset-4") + assert updated_asset.mime_type is None + assert updated_asset.hash is None + + def test_duplicate_hash_merges_into_existing_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify duplicate files merge into existing asset instead of failing.""" + file_path_1 = temp_dir / "file1.bin" + file_path_2 = temp_dir / "file2.bin" + content = b"identical content" + file_path_1.write_bytes(content) + file_path_2.write_bytes(content) + + asset1, ref1 = _create_stub_asset( + session, str(file_path_1), "asset-dup-1", "ref-dup-1" + ) + asset2, ref2 = _create_stub_asset( + session, str(file_path_2), "asset-dup-2", "ref-dup-2" + ) + session.commit() + + enrich_asset( + session, + file_path=str(file_path_1), + reference_id=ref1.id, + asset_id=asset1.id, + extract_metadata=True, + compute_hash=True, + ) + + enrich_asset( + session, + file_path=str(file_path_2), + reference_id=ref2.id, + asset_id=asset2.id, + extract_metadata=True, + compute_hash=True, + ) + + session.expire_all() + + updated_asset1 = session.get(Asset, "asset-dup-1") + assert updated_asset1 is not None + assert updated_asset1.hash is not None + + updated_asset2 = session.get(Asset, "asset-dup-2") + assert updated_asset2 is None + + updated_ref2 = session.get(AssetReference, "ref-dup-2") + assert updated_ref2 is not None + assert updated_ref2.asset_id == "asset-dup-1" diff --git a/tests-unit/assets_test/services/test_ingest.py b/tests-unit/assets_test/services/test_ingest.py new file mode 100644 index 000000000..367bc7721 --- /dev/null +++ b/tests-unit/assets_test/services/test_ingest.py @@ -0,0 +1,229 @@ +"""Tests for ingest services.""" +from pathlib import Path + +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, Tag +from app.assets.database.queries import get_reference_tags +from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset + + +class TestIngestFileFromPath: + def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "test_file.bin" + file_path.write_bytes(b"test content") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:abc123", + size_bytes=12, + mtime_ns=1234567890000000000, + mime_type="application/octet-stream", + ) + + assert result.asset_created is True + assert result.ref_created is True + assert result.reference_id is not None + + # Verify DB state + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].hash == "blake3:abc123" + + refs = session.query(AssetReference).all() + assert len(refs) == 1 + assert refs[0].file_path == str(file_path) + + def test_creates_reference_when_name_provided(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"model data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:def456", + size_bytes=10, + mtime_ns=1234567890000000000, + mime_type="application/octet-stream", + info_name="My Model", + owner_id="user1", + ) + + assert result.asset_created is True + assert result.reference_id is not None + + ref = session.query(AssetReference).first() + assert ref is not None + assert ref.name == "My Model" + assert ref.owner_id == "user1" + + def test_creates_tags_when_provided(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "tagged.bin" + file_path.write_bytes(b"data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:ghi789", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="Tagged Asset", + tags=["models", "checkpoints"], + ) + + assert result.reference_id is not None + + # Verify tags were created and linked + tags = session.query(Tag).all() + tag_names = {t.name for t in tags} + assert "models" in tag_names + assert "checkpoints" in tag_names + + ref_tags = get_reference_tags(session, reference_id=result.reference_id) + assert set(ref_tags) == {"models", "checkpoints"} + + def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "dup.bin" + file_path.write_bytes(b"content") + + # First ingest + r1 = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:repeat", + size_bytes=7, + mtime_ns=1234567890000000000, + ) + assert r1.asset_created is True + + # Second ingest with same hash - should update, not create + r2 = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:repeat", + size_bytes=7, + mtime_ns=1234567890000000001, # different mtime + ) + assert r2.asset_created is False + assert r2.ref_created is False + assert r2.ref_updated is True + + # Still only one asset + assets = session.query(Asset).all() + assert len(assets) == 1 + + def test_validates_preview_id(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "with_preview.bin" + file_path.write_bytes(b"data") + + # Create a preview asset first + preview_asset = Asset(hash="blake3:preview", size_bytes=100) + session.add(preview_asset) + session.commit() + preview_id = preview_asset.id + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:main", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="With Preview", + preview_id=preview_id, + ) + + assert result.reference_id is not None + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.preview_id == preview_id + + def test_invalid_preview_id_is_cleared(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "bad_preview.bin" + file_path.write_bytes(b"data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:badpreview", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="Bad Preview", + preview_id="nonexistent-uuid", + ) + + assert result.reference_id is not None + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.preview_id is None + + +class TestRegisterExistingAsset: + def test_creates_reference_for_existing_asset(self, mock_create_session, session: Session): + # Create existing asset + asset = Asset(hash="blake3:existing", size_bytes=1024, mime_type="image/png") + session.add(asset) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:existing", + name="Registered Asset", + user_metadata={"key": "value"}, + tags=["models"], + ) + + assert result.created is True + assert "models" in result.tags + + # Verify by re-fetching from DB + session.expire_all() + refs = session.query(AssetReference).filter_by(name="Registered Asset").all() + assert len(refs) == 1 + + def test_creates_new_reference_even_with_same_name(self, mock_create_session, session: Session): + # Create asset and reference + asset = Asset(hash="blake3:withref", size_bytes=512) + session.add(asset) + session.flush() + + from app.assets.helpers import get_utc_now + ref = AssetReference( + owner_id="", + name="Existing Ref", + asset_id=asset.id, + created_at=get_utc_now(), + updated_at=get_utc_now(), + last_access_time=get_utc_now(), + ) + session.add(ref) + session.flush() + ref_id = ref.id + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:withref", + name="Existing Ref", + owner_id="", + ) + + # Multiple files with same name are allowed + assert result.created is True + + # Verify two AssetReferences exist for this name + session.expire_all() + refs = session.query(AssetReference).filter_by(name="Existing Ref").all() + assert len(refs) == 2 + assert ref_id in [r.id for r in refs] + + def test_raises_for_nonexistent_hash(self, mock_create_session): + with pytest.raises(ValueError, match="No asset with hash"): + _register_existing_asset( + asset_hash="blake3:doesnotexist", + name="Fail", + ) + + def test_applies_tags_to_new_reference(self, mock_create_session, session: Session): + asset = Asset(hash="blake3:tagged", size_bytes=256) + session.add(asset) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:tagged", + name="Tagged Ref", + tags=["alpha", "beta"], + ) + + assert result.created is True + assert set(result.tags) == {"alpha", "beta"} diff --git a/tests-unit/assets_test/services/test_tagging.py b/tests-unit/assets_test/services/test_tagging.py new file mode 100644 index 000000000..ab69e5dc1 --- /dev/null +++ b/tests-unit/assets_test/services/test_tagging.py @@ -0,0 +1,197 @@ +"""Tests for tagging services.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference +from app.assets.helpers import get_utc_now +from app.assets.services import apply_tags, remove_tags, list_tags + + +def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestApplyTags: + def test_adds_new_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + session.commit() + + result = apply_tags( + reference_id=ref.id, + tags=["alpha", "beta"], + ) + + assert set(result.added) == {"alpha", "beta"} + assert result.already_present == [] + assert set(result.total_tags) == {"alpha", "beta"} + + def test_reports_already_present(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["existing"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["existing"]) + session.commit() + + result = apply_tags( + reference_id=ref.id, + tags=["existing", "new"], + ) + + assert result.added == ["new"] + assert result.already_present == ["existing"] + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + apply_tags(reference_id="nonexistent", tags=["x"]) + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + apply_tags( + reference_id=ref.id, + tags=["new"], + owner_id="user2", + ) + + +class TestRemoveTags: + def test_removes_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["a", "b", "c"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["a", "b", "c"]) + session.commit() + + result = remove_tags( + reference_id=ref.id, + tags=["a", "b"], + ) + + assert set(result.removed) == {"a", "b"} + assert result.not_present == [] + assert result.total_tags == ["c"] + + def test_reports_not_present(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["present"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["present"]) + session.commit() + + result = remove_tags( + reference_id=ref.id, + tags=["present", "absent"], + ) + + assert result.removed == ["present"] + assert result.not_present == ["absent"] + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + remove_tags(reference_id="nonexistent", tags=["x"]) + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + remove_tags( + reference_id=ref.id, + tags=["x"], + owner_id="user2", + ) + + +class TestListTags: + def test_returns_tags_with_counts(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + asset = _make_asset(session) + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags() + + tag_dict = {name: count for name, _, count in rows} + assert tag_dict["used"] == 1 + assert tag_dict["unused"] == 0 + assert total == 2 + + def test_excludes_zero_counts(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + asset = _make_asset(session) + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags(include_zero=False) + + tag_names = {name for name, _, _ in rows} + assert "used" in tag_names + assert "unused" not in tag_names + + def test_prefix_filter(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["alpha", "beta", "alphabet"]) + session.commit() + + rows, _ = list_tags(prefix="alph") + + tag_names = {name for name, _, _ in rows} + assert tag_names == {"alpha", "alphabet"} + + def test_order_by_name(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["zebra", "alpha", "middle"]) + session.commit() + + rows, _ = list_tags(order="name_asc") + + names = [name for name, _, _ in rows] + assert names == ["alpha", "middle", "zebra"] + + def test_pagination(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["a", "b", "c", "d", "e"]) + session.commit() + + rows, total = list_tags(limit=2, offset=1, order="name_asc") + + assert total == 5 + assert len(rows) == 2 + names = [name for name, _, _ in rows] + assert names == ["b", "c"] + + def test_clamps_limit(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["a"]) + session.commit() + + # Service should clamp limit to max 1000 + rows, _ = list_tags(limit=2000) + assert len(rows) <= 1000 diff --git a/tests-unit/assets_test/test_assets_missing_sync.py b/tests-unit/assets_test/test_assets_missing_sync.py index 78fa7b404..47dc130cb 100644 --- a/tests-unit/assets_test/test_assets_missing_sync.py +++ b/tests-unit/assets_test/test_assets_missing_sync.py @@ -4,7 +4,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets diff --git a/tests-unit/assets_test/test_crud.py b/tests-unit/assets_test/test_crud.py index d2b69f475..07310223e 100644 --- a/tests-unit/assets_test/test_crud.py +++ b/tests-unit/assets_test/test_crud.py @@ -4,7 +4,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets def test_create_from_hash_success( @@ -24,11 +24,11 @@ def test_create_from_hash_success( assert b1["created_new"] is False aid = b1["id"] - # Calling again with the same name should return the same AssetInfo id + # Calling again with the same name creates a new AssetReference (duplicates allowed) r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) b2 = r2.json() assert r2.status_code == 201, b2 - assert b2["id"] == aid + assert b2["id"] != aid # new reference, not the same one def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asset: dict): @@ -42,8 +42,8 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse assert "user_metadata" in detail assert "filename" in detail["user_metadata"] - # DELETE - rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + # DELETE (hard delete to also remove underlying asset and file) + rd = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) assert rd.status_code == 204 # GET again -> 404 @@ -53,6 +53,35 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse assert body["error"]["code"] == "ASSET_NOT_FOUND" +def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + asset_hash = seeded_asset["asset_hash"] + + # Soft-delete (default, no delete_content param) + rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + assert rd.status_code == 204 + + # GET by reference ID -> 404 (soft-deleted references are hidden) + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + assert rg.status_code == 404 + + # Asset identity is preserved (underlying content still exists) + rh = http.head(f"{api_base}/api/assets/hash/{asset_hash}", timeout=120) + assert rh.status_code == 200 + + # Soft-deleted reference should not appear in listings + rl = http.get( + f"{api_base}/api/assets", + params={"include_tags": "unit-tests", "limit": "500"}, + timeout=120, + ) + ids = [a["id"] for a in rl.json().get("assets", [])] + assert aid not in ids + + # Clean up: hard-delete the soft-deleted reference and orphaned asset + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) + + def test_delete_upon_reference_count( http: requests.Session, api_base: str, seeded_asset: dict ): @@ -70,21 +99,32 @@ def test_delete_upon_reference_count( assert copy["asset_hash"] == src_hash assert copy["created_new"] is False - # Delete original reference -> asset identity must remain + # Soft-delete original reference (default) -> asset identity must remain aid1 = seeded_asset["id"] rd1 = http.delete(f"{api_base}/api/assets/{aid1}", timeout=120) assert rd1.status_code == 204 rh1 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) - assert rh1.status_code == 200 # identity still present + assert rh1.status_code == 200 # identity still present (second ref exists) - # Delete the last reference with default semantics -> identity and cached files removed + # Soft-delete the last reference -> asset identity preserved (no hard delete) aid2 = copy["id"] rd2 = http.delete(f"{api_base}/api/assets/{aid2}", timeout=120) assert rd2.status_code == 204 rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) - assert rh2.status_code == 404 # orphan content removed + assert rh2.status_code == 200 # asset identity preserved (soft delete) + + # Re-associate via from-hash, then hard-delete -> orphan content removed + r3 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + assert r3.status_code == 201, r3.json() + aid3 = r3.json()["id"] + + rd3 = http.delete(f"{api_base}/api/assets/{aid3}?delete_content=true", timeout=120) + assert rd3.status_code == 204 + + rh3 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) + assert rh3.status_code == 404 # orphan content removed def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict): @@ -126,42 +166,52 @@ def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api assert body == b"" -def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str): - bogus = str(uuid.uuid4()) - r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120) +@pytest.mark.parametrize( + "method,endpoint_template,payload,expected_status,error_code", + [ + # Delete nonexistent asset + ("delete", "/api/assets/{uuid}", None, 404, "ASSET_NOT_FOUND"), + # Bad hash algorithm in from-hash + ( + "post", + "/api/assets/from-hash", + {"hash": "sha256:" + "0" * 64, "name": "x.bin", "tags": ["models", "checkpoints", "unit-tests"]}, + 400, + "INVALID_BODY", + ), + # Get with bad UUID format + ("get", "/api/assets/not-a-uuid", None, 404, None), + # Get content with bad UUID format + ("get", "/api/assets/not-a-uuid/content", None, 404, None), + ], + ids=["delete_nonexistent", "bad_hash_algorithm", "get_bad_uuid", "content_bad_uuid"], +) +def test_error_responses( + http: requests.Session, api_base: str, method, endpoint_template, payload, expected_status, error_code +): + # Replace {uuid} placeholder with a random UUID for delete test + endpoint = endpoint_template.replace("{uuid}", str(uuid.uuid4())) + url = f"{api_base}{endpoint}" + + if method == "get": + r = http.get(url, timeout=120) + elif method == "post": + r = http.post(url, json=payload, timeout=120) + elif method == "delete": + r = http.delete(url, timeout=120) + + assert r.status_code == expected_status + if error_code: + body = r.json() + assert body["error"]["code"] == error_code + + +def test_create_from_hash_invalid_json(http: requests.Session, api_base: str): + """Invalid JSON body requires special handling (data= instead of json=).""" + r = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120) body = r.json() - assert r.status_code == 404 - assert body["error"]["code"] == "ASSET_NOT_FOUND" - - -def test_create_from_hash_invalids(http: requests.Session, api_base: str): - # Bad hash algorithm - bad = { - "hash": "sha256:" + "0" * 64, - "name": "x.bin", - "tags": ["models", "checkpoints", "unit-tests"], - } - r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_BODY" - - # Invalid JSON body - r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_JSON" - - -def test_get_update_download_bad_ids(http: requests.Session, api_base: str): - # All endpoints should be not found, as we UUID regex directly in the route definition. - bad_id = "not-a-uuid" - - r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120) - assert r1.status_code == 404 - - r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120) - assert r3.status_code == 404 + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_JSON" def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict): diff --git a/tests-unit/assets_test/test_downloads.py b/tests-unit/assets_test/test_downloads.py index cdebf9082..672ba9728 100644 --- a/tests-unit/assets_test/test_downloads.py +++ b/tests-unit/assets_test/test_downloads.py @@ -6,7 +6,7 @@ from typing import Optional import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict): @@ -117,7 +117,7 @@ def test_download_missing_file_returns_404( assert body["error"]["code"] == "FILE_NOT_FOUND" finally: # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. - dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + dr = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) dr.content diff --git a/tests-unit/assets_test/test_file_utils.py b/tests-unit/assets_test/test_file_utils.py new file mode 100644 index 000000000..e3591d49b --- /dev/null +++ b/tests-unit/assets_test/test_file_utils.py @@ -0,0 +1,121 @@ +import os +import sys + +import pytest + +from app.assets.services.file_utils import is_visible, list_files_recursively + + +class TestIsVisible: + def test_visible_file(self): + assert is_visible("file.txt") is True + + def test_hidden_file(self): + assert is_visible(".hidden") is False + + def test_hidden_directory(self): + assert is_visible(".git") is False + + def test_visible_directory(self): + assert is_visible("src") is True + + def test_dotdot_is_hidden(self): + assert is_visible("..") is False + + def test_dot_is_hidden(self): + assert is_visible(".") is False + + +class TestListFilesRecursively: + def test_skips_hidden_files(self, tmp_path): + (tmp_path / "visible.txt").write_text("a") + (tmp_path / ".hidden").write_text("b") + + result = list_files_recursively(str(tmp_path)) + + assert len(result) == 1 + assert result[0].endswith("visible.txt") + + def test_skips_hidden_directories(self, tmp_path): + hidden_dir = tmp_path / ".hidden_dir" + hidden_dir.mkdir() + (hidden_dir / "file.txt").write_text("a") + + visible_dir = tmp_path / "visible_dir" + visible_dir.mkdir() + (visible_dir / "file.txt").write_text("b") + + result = list_files_recursively(str(tmp_path)) + + assert len(result) == 1 + assert "visible_dir" in result[0] + assert ".hidden_dir" not in result[0] + + def test_empty_directory(self, tmp_path): + result = list_files_recursively(str(tmp_path)) + assert result == [] + + def test_nonexistent_directory(self, tmp_path): + result = list_files_recursively(str(tmp_path / "nonexistent")) + assert result == [] + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_follows_symlinked_directories(self, tmp_path): + target = tmp_path / "real_dir" + target.mkdir() + (target / "model.safetensors").write_text("data") + + root = tmp_path / "root" + root.mkdir() + (root / "link").symlink_to(target) + + result = list_files_recursively(str(root)) + + assert len(result) == 1 + assert result[0].endswith("model.safetensors") + assert "link" in result[0] + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_follows_symlinked_files(self, tmp_path): + real_file = tmp_path / "real.txt" + real_file.write_text("content") + + root = tmp_path / "root" + root.mkdir() + (root / "link.txt").symlink_to(real_file) + + result = list_files_recursively(str(root)) + + assert len(result) == 1 + assert result[0].endswith("link.txt") + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_circular_symlinks_do_not_loop(self, tmp_path): + dir_a = tmp_path / "a" + dir_a.mkdir() + (dir_a / "file.txt").write_text("a") + # a/b -> a (circular) + (dir_a / "b").symlink_to(dir_a) + + result = list_files_recursively(str(dir_a)) + + assert len(result) == 1 + assert result[0].endswith("file.txt") + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_mutual_circular_symlinks(self, tmp_path): + dir_a = tmp_path / "a" + dir_b = tmp_path / "b" + dir_a.mkdir() + dir_b.mkdir() + (dir_a / "file_a.txt").write_text("a") + (dir_b / "file_b.txt").write_text("b") + # a/link_b -> b and b/link_a -> a + (dir_a / "link_b").symlink_to(dir_b) + (dir_b / "link_a").symlink_to(dir_a) + + result = list_files_recursively(str(dir_a)) + basenames = sorted(os.path.basename(p) for p in result) + + assert "file_a.txt" in basenames + assert "file_b.txt" in basenames diff --git a/tests-unit/assets_test/test_list_filter.py b/tests-unit/assets_test/test_list_filter.py index 82e109832..dcb7a73ca 100644 --- a/tests-unit/assets_test/test_list_filter.py +++ b/tests-unit/assets_test/test_list_filter.py @@ -1,6 +1,7 @@ import time import uuid +import pytest import requests @@ -283,30 +284,21 @@ def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asse assert b2["has_more"] is False -def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base): - r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_QUERY" - - r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_QUERY" - - -def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str): - # limit too small - r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_QUERY" - - # bad metadata JSON - r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_QUERY" +@pytest.mark.parametrize( + "params,error_code", + [ + ({"offset": "-1"}, "INVALID_QUERY"), + ({"limit": "abc"}, "INVALID_QUERY"), + ({"limit": "0"}, "INVALID_QUERY"), + ({"metadata_filter": "{not json"}, "INVALID_QUERY"), + ], + ids=["negative_offset", "non_int_limit", "zero_limit", "invalid_metadata_json"], +) +def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str, params, error_code): + r = http.get(api_base + "/api/assets", params=params, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == error_code def test_list_assets_name_contains_literal_underscore( diff --git a/tests-unit/assets_test/test_prune_orphaned_assets.py b/tests-unit/assets_test/test_prune_orphaned_assets.py index f602e5a77..1fbd4d4e2 100644 --- a/tests-unit/assets_test/test_prune_orphaned_assets.py +++ b/tests-unit/assets_test/test_prune_orphaned_assets.py @@ -3,7 +3,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets @pytest.fixture diff --git a/tests-unit/assets_test/test_sync_references.py b/tests-unit/assets_test/test_sync_references.py new file mode 100644 index 000000000..94cc255bc --- /dev/null +++ b/tests-unit/assets_test/test_sync_references.py @@ -0,0 +1,482 @@ +"""Tests for sync_references_with_filesystem in scanner.py.""" + +import os +import tempfile +from datetime import datetime +from pathlib import Path +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import ( + Asset, + AssetReference, + AssetReferenceTag, + Base, + Tag, +) +from app.assets.database.queries.asset_reference import ( + bulk_insert_references_ignore_conflicts, + get_references_for_prefixes, + get_unenriched_references, + restore_references_by_paths, +) +from app.assets.scanner import sync_references_with_filesystem +from app.assets.services.file_utils import get_mtime_ns + + +@pytest.fixture +def db_engine(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def session(db_engine): + with Session(db_engine) as sess: + yield sess + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +def _create_file(temp_dir: Path, name: str, content: bytes = b"\x00" * 100) -> str: + """Create a file and return its absolute path (no symlink resolution).""" + p = temp_dir / name + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(content) + return os.path.abspath(str(p)) + + +def _stat_mtime_ns(path: str) -> int: + return get_mtime_ns(os.stat(path, follow_symlinks=True)) + + +def _make_asset( + session: Session, + asset_id: str, + file_path: str, + ref_id: str, + *, + asset_hash: str | None = None, + size_bytes: int = 100, + mtime_ns: int | None = None, + needs_verify: bool = False, + is_missing: bool = False, +) -> tuple[Asset, AssetReference]: + """Insert an Asset + AssetReference and flush.""" + asset = session.get(Asset, asset_id) + if asset is None: + asset = Asset(id=asset_id, hash=asset_hash, size_bytes=size_bytes) + session.add(asset) + session.flush() + + ref = AssetReference( + id=ref_id, + asset_id=asset_id, + name=f"test-{ref_id}", + owner_id="system", + file_path=file_path, + mtime_ns=mtime_ns, + needs_verify=needs_verify, + is_missing=is_missing, + ) + session.add(ref) + session.flush() + return asset, ref + + +def _ensure_missing_tag(session: Session): + """Ensure the 'missing' tag exists.""" + if not session.get(Tag, "missing"): + session.add(Tag(name="missing", tag_type="system")) + session.flush() + + +class _VerifyCase: + def __init__(self, id, stat_unchanged, needs_verify_before, expect_needs_verify): + self.id = id + self.stat_unchanged = stat_unchanged + self.needs_verify_before = needs_verify_before + self.expect_needs_verify = expect_needs_verify + + +VERIFY_CASES = [ + _VerifyCase( + id="unchanged_clears_verify", + stat_unchanged=True, + needs_verify_before=True, + expect_needs_verify=False, + ), + _VerifyCase( + id="unchanged_keeps_clear", + stat_unchanged=True, + needs_verify_before=False, + expect_needs_verify=False, + ), + _VerifyCase( + id="changed_sets_verify", + stat_unchanged=False, + needs_verify_before=False, + expect_needs_verify=True, + ), + _VerifyCase( + id="changed_keeps_verify", + stat_unchanged=False, + needs_verify_before=True, + expect_needs_verify=True, + ), +] + + +@pytest.mark.parametrize("case", VERIFY_CASES, ids=lambda c: c.id) +def test_needs_verify_toggling(session, temp_dir, case): + """needs_verify is set/cleared based on mtime+size match.""" + fp = _create_file(temp_dir, "model.bin") + real_mtime = _stat_mtime_ns(fp) + + mtime_for_db = real_mtime if case.stat_unchanged else real_mtime + 1 + _make_asset( + session, "a1", fp, "r1", + asset_hash="blake3:abc", + mtime_ns=mtime_for_db, + needs_verify=case.needs_verify_before, + ) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.needs_verify is case.expect_needs_verify + + +class _MissingCase: + def __init__(self, id, file_exists, expect_is_missing): + self.id = id + self.file_exists = file_exists + self.expect_is_missing = expect_is_missing + + +MISSING_CASES = [ + _MissingCase(id="existing_file_not_missing", file_exists=True, expect_is_missing=False), + _MissingCase(id="missing_file_marked_missing", file_exists=False, expect_is_missing=True), +] + + +@pytest.mark.parametrize("case", MISSING_CASES, ids=lambda c: c.id) +def test_is_missing_flag(session, temp_dir, case): + """is_missing reflects whether the file exists on disk.""" + if case.file_exists: + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + else: + fp = str(temp_dir / "gone.bin") + mtime = 999 + + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.is_missing is case.expect_is_missing + + +def test_seed_asset_all_missing_deletes_asset(session, temp_dir): + """Seed asset with all refs missing gets deleted entirely.""" + fp = str(temp_dir / "gone.bin") + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + assert session.get(Asset, "seed1") is None + assert session.get(AssetReference, "r1") is None + + +def test_seed_asset_some_exist_returns_survivors(session, temp_dir): + """Seed asset with at least one existing ref survives and is returned.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + survivors = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + session.commit() + + assert session.get(Asset, "seed1") is not None + assert os.path.abspath(fp) in survivors + + +def test_hashed_asset_prunes_missing_refs_when_one_is_ok(session, temp_dir): + """Hashed asset with one stat-unchanged ref deletes missing refs.""" + fp_ok = _create_file(temp_dir, "good.bin") + fp_gone = str(temp_dir / "gone.bin") + mtime = _stat_mtime_ns(fp_ok) + + _make_asset(session, "h1", fp_ok, "r_ok", asset_hash="blake3:aaa", mtime_ns=mtime) + # Second ref on same asset, file missing + ref_gone = AssetReference( + id="r_gone", asset_id="h1", name="gone", + owner_id="system", file_path=fp_gone, mtime_ns=999, + ) + session.add(ref_gone) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + assert session.get(AssetReference, "r_ok") is not None + assert session.get(AssetReference, "r_gone") is None + + +def test_hashed_asset_all_missing_keeps_refs(session, temp_dir): + """Hashed asset with all refs missing keeps refs (no pruning).""" + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + assert session.get(AssetReference, "r1") is not None + ref = session.get(AssetReference, "r1") + assert ref.is_missing is True + + +def test_missing_tag_added_when_all_refs_gone(session, temp_dir): + """Missing tag is added to hashed asset when all refs are missing.""" + _ensure_missing_tag(session) + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=True, + ) + session.commit() + + session.expire_all() + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is not None + + +def test_missing_tag_removed_when_ref_ok(session, temp_dir): + """Missing tag is removed from hashed asset when a ref is stat-unchanged.""" + _ensure_missing_tag(session) + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=mtime) + # Pre-add a stale missing tag + session.add(AssetReferenceTag( + asset_reference_id="r1", tag_name="missing", origin="automatic", + )) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=True, + ) + session.commit() + + session.expire_all() + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is None + + +def test_missing_tags_not_touched_when_flag_false(session, temp_dir): + """Missing tags are not modified when update_missing_tags=False.""" + _ensure_missing_tag(session) + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=False, + ) + session.commit() + + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is None # tag was never added + + +def test_returns_none_when_collect_false(session, temp_dir): + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + result = sync_references_with_filesystem( + session, "models", collect_existing_paths=False, + ) + + assert result is None + + +def test_returns_empty_set_for_no_prefixes(session): + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[]): + result = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + + assert result == set() + + +def test_no_references_is_noop(session, temp_dir): + """No crash and no side effects when there are no references.""" + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + survivors = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + session.commit() + + assert survivors == set() + + +# --------------------------------------------------------------------------- +# Soft-delete persistence across scanner operations +# --------------------------------------------------------------------------- + +def _soft_delete_ref(session: Session, ref_id: str) -> None: + """Mark a reference as soft-deleted (mimics the API DELETE behaviour).""" + ref = session.get(AssetReference, ref_id) + ref.deleted_at = datetime(2025, 1, 1) + session.flush() + + +def test_soft_deleted_ref_excluded_from_get_references_for_prefixes(session, temp_dir): + """get_references_for_prefixes skips soft-deleted references.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + rows = get_references_for_prefixes(session, [str(temp_dir)], include_missing=True) + assert len(rows) == 0 + + +def test_sync_does_not_resurrect_soft_deleted_ref(session, temp_dir): + """Scanner sync leaves soft-deleted refs untouched even when file exists on disk.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.deleted_at is not None, "soft-deleted ref must stay deleted after sync" + + +def test_bulk_insert_does_not_overwrite_soft_deleted_ref(session, temp_dir): + """bulk_insert_references_ignore_conflicts cannot replace a soft-deleted row.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + now = datetime.now(tz=None) + bulk_insert_references_ignore_conflicts(session, [ + { + "id": "r_new", + "asset_id": "a1", + "file_path": fp, + "name": "model.bin", + "owner_id": "", + "mtime_ns": mtime, + "preview_id": None, + "user_metadata": None, + "created_at": now, + "updated_at": now, + "last_access_time": now, + } + ]) + session.commit() + + session.expire_all() + # Original row is still the soft-deleted one + ref = session.get(AssetReference, "r1") + assert ref is not None + assert ref.deleted_at is not None + # The new row was not inserted (conflict on file_path) + assert session.get(AssetReference, "r_new") is None + + +def test_restore_references_by_paths_skips_soft_deleted(session, temp_dir): + """restore_references_by_paths does not clear is_missing on soft-deleted refs.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset( + session, "a1", fp, "r1", + asset_hash="blake3:abc", mtime_ns=mtime, is_missing=True, + ) + _soft_delete_ref(session, "r1") + session.commit() + + restored = restore_references_by_paths(session, [fp]) + session.commit() + + assert restored == 0 + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.is_missing is True, "is_missing must not be cleared on soft-deleted ref" + assert ref.deleted_at is not None + + +def test_get_unenriched_references_excludes_soft_deleted(session, temp_dir): + """Enrichment queries do not pick up soft-deleted references.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + rows = get_unenriched_references(session, [str(temp_dir)], max_level=2) + assert len(rows) == 0 + + +def test_sync_ignores_soft_deleted_seed_asset(session, temp_dir): + """Soft-deleted seed ref is not garbage-collected even when file is missing.""" + fp = str(temp_dir / "gone.bin") # file does not exist + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=999) + _soft_delete_ref(session, "r1") + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + # Asset and ref must still exist — scanner did not see the soft-deleted row + assert session.get(Asset, "seed1") is not None + assert session.get(AssetReference, "r1") is not None diff --git a/tests-unit/assets_test/test_tags.py b/tests-unit/assets_test/test_tags_api.py similarity index 98% rename from tests-unit/assets_test/test_tags.py rename to tests-unit/assets_test/test_tags_api.py index 6b1047802..595bf29c6 100644 --- a/tests-unit/assets_test/test_tags.py +++ b/tests-unit/assets_test/test_tags_api.py @@ -69,8 +69,8 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, used_names = [t["name"] for t in body2["tags"]] assert custom_tag in used_names - # Delete the asset so the tag usage drops to zero - rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120) + # Hard-delete the asset so the tag usage drops to zero + rd = http.delete(f"{api_base}/api/assets/{_asset['id']}?delete_content=true", timeout=120) assert rd.status_code == 204 # Now the custom tag must not be returned when include_zero=false diff --git a/tests-unit/assets_test/test_uploads.py b/tests-unit/assets_test/test_uploads.py index 137d7391a..d68e5b5d7 100644 --- a/tests-unit/assets_test/test_uploads.py +++ b/tests-unit/assets_test/test_uploads.py @@ -18,25 +18,25 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma assert r1.status_code == 201, a1 assert a1["created_new"] is True - # Second upload with the same data and name should return created_new == False and the same asset + # Second upload with the same data and name creates a new AssetReference (duplicates allowed) + # Returns 200 because Asset already exists, but a new AssetReference is created files = {"file": (name, data, "application/octet-stream")} form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) a2 = r2.json() - assert r2.status_code == 200, a2 - assert a2["created_new"] is False + assert r2.status_code in (200, 201), a2 assert a2["asset_hash"] == a1["asset_hash"] - assert a2["id"] == a1["id"] # old reference + assert a2["id"] != a1["id"] # new reference with same content - # Third upload with the same data but new name should return created_new == False and the new AssetReference + # Third upload with the same data but different name also creates new AssetReference files = {"file": (name, data, "application/octet-stream")} form = {"tags": json.dumps(tags), "name": name + "_d", "user_metadata": json.dumps(meta)} - r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) - a3 = r2.json() - assert r2.status_code == 200, a3 - assert a3["created_new"] is False + r3 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + a3 = r3.json() + assert r3.status_code in (200, 201), a3 assert a3["asset_hash"] == a1["asset_hash"] - assert a3["id"] != a1["id"] # old reference + assert a3["id"] != a1["id"] + assert a3["id"] != a2["id"] def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str): @@ -116,7 +116,7 @@ def test_concurrent_upload_identical_bytes_different_names( ): """ Two concurrent uploads of identical bytes but different names. - Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True. + Expect a single Asset (same hash), two AssetReference rows, and exactly one created_new=True. """ scope = f"concupload-{uuid.uuid4().hex[:6]}" name1, name2 = "cu_a.bin", "cu_b.bin" diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index 2355b8000..3a6790ee0 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -2,4 +2,3 @@ pytest>=7.8.0 pytest-aiohttp pytest-asyncio websocket-client -blake3 diff --git a/tests-unit/seeder_test/test_seeder.py b/tests-unit/seeder_test/test_seeder.py new file mode 100644 index 000000000..db3795e48 --- /dev/null +++ b/tests-unit/seeder_test/test_seeder.py @@ -0,0 +1,900 @@ +"""Unit tests for the _AssetSeeder background scanning class.""" + +import threading +from unittest.mock import patch + +import pytest + +from app.assets.database.queries.asset_reference import UnenrichedReferenceRow +from app.assets.seeder import _AssetSeeder, Progress, ScanInProgressError, ScanPhase, State + + +@pytest.fixture +def fresh_seeder(): + """Create a fresh _AssetSeeder instance for testing.""" + seeder = _AssetSeeder() + yield seeder + seeder.shutdown(timeout=1.0) + + +@pytest.fixture +def mock_dependencies(): + """Mock all external dependencies for isolated testing.""" + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + yield + + +class TestSeederStateTransitions: + """Test state machine transitions.""" + + def test_initial_state_is_idle(self, fresh_seeder: _AssetSeeder): + assert fresh_seeder.get_status().state == State.IDLE + + def test_start_transitions_to_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + started = fresh_seeder.start(roots=("models",)) + assert started is True + assert reached.wait(timeout=2.0) + assert fresh_seeder.get_status().state == State.RUNNING + + barrier.set() + + def test_start_while_running_returns_false( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + second_start = fresh_seeder.start(roots=("models",)) + assert second_start is False + + barrier.set() + + def test_cancel_transitions_to_cancelling( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + cancelled = fresh_seeder.cancel() + assert cancelled is True + assert fresh_seeder.get_status().state == State.CANCELLING + + barrier.set() + + def test_cancel_when_idle_returns_false(self, fresh_seeder: _AssetSeeder): + cancelled = fresh_seeder.cancel() + assert cancelled is False + + def test_state_returns_to_idle_after_completion( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=5.0) + assert completed is True + assert fresh_seeder.get_status().state == State.IDLE + + +class TestSeederWait: + """Test wait() behavior.""" + + def test_wait_blocks_until_complete( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=5.0) + assert completed is True + assert fresh_seeder.get_status().state == State.IDLE + + def test_wait_returns_false_on_timeout( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=10.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=0.1) + assert completed is False + + barrier.set() + + def test_wait_when_idle_returns_true(self, fresh_seeder: _AssetSeeder): + completed = fresh_seeder.wait(timeout=1.0) + assert completed is True + + +class TestSeederProgress: + """Test progress tracking.""" + + def test_get_status_returns_progress_during_scan( + self, fresh_seeder: _AssetSeeder + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_build(*args, **kwargs): + reached.set() + barrier.wait(timeout=5.0) + return ([], set(), 0) + + paths = ["/path/file1.safetensors", "/path/file2.safetensors"] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), + patch("app.assets.seeder.build_asset_specs", side_effect=slow_build), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + status = fresh_seeder.get_status() + assert status.state == State.RUNNING + assert status.progress is not None + assert status.progress.total == 2 + + barrier.set() + + def test_progress_callback_is_invoked( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + progress_updates: list[Progress] = [] + + def callback(p: Progress): + progress_updates.append(p) + + with patch( + "app.assets.seeder.collect_paths_for_roots", + return_value=[f"/path/file{i}.safetensors" for i in range(10)], + ): + fresh_seeder.start(roots=("models",), progress_callback=callback) + fresh_seeder.wait(timeout=5.0) + + assert len(progress_updates) > 0 + + +class TestSeederCancellation: + """Test cancellation behavior.""" + + def test_scan_commits_partial_progress_on_cancellation( + self, fresh_seeder: _AssetSeeder + ): + insert_count = 0 + barrier = threading.Event() + first_insert_done = threading.Event() + + def slow_insert(specs, tags): + nonlocal insert_count + insert_count += 1 + if insert_count == 1: + first_insert_done.set() + if insert_count >= 2: + barrier.wait(timeout=5.0) + return len(specs) + + paths = [f"/path/file{i}.safetensors" for i in range(1500)] + specs = [ + { + "abs_path": p, + "size_bytes": 100, + "mtime_ns": 0, + "info_name": f"file{i}", + "tags": [], + "fname": f"file{i}", + "metadata": None, + "hash": None, + "mime_type": None, + } + for i, p in enumerate(paths) + ] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), + patch( + "app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0) + ), + patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + assert first_insert_done.wait(timeout=2.0) + + fresh_seeder.cancel() + barrier.set() + fresh_seeder.wait(timeout=5.0) + + assert 1 <= insert_count < 3 # 1500 paths / 500 batch = 3; cancel stopped early + + +class TestSeederErrorHandling: + """Test error handling behavior.""" + + def test_database_errors_captured_in_status(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch( + "app.assets.seeder.collect_paths_for_roots", + return_value=["/path/file.safetensors"], + ), + patch( + "app.assets.seeder.build_asset_specs", + return_value=( + [ + { + "abs_path": "/path/file.safetensors", + "size_bytes": 100, + "mtime_ns": 0, + "info_name": "file", + "tags": [], + "fname": "file", + "metadata": None, + "hash": None, + "mime_type": None, + } + ], + set(), + 0, + ), + ), + patch( + "app.assets.seeder.insert_asset_specs", + side_effect=Exception("DB connection failed"), + ), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert len(status.errors) > 0 + assert "DB connection failed" in status.errors[0] + + def test_dependencies_unavailable_captured_in_errors( + self, fresh_seeder: _AssetSeeder + ): + with patch("app.assets.seeder.dependencies_available", return_value=False): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert len(status.errors) > 0 + assert "dependencies" in status.errors[0].lower() + + def test_thread_crash_resets_state_to_idle(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch( + "app.assets.seeder.sync_root_safely", + side_effect=RuntimeError("Unexpected crash"), + ), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert status.state == State.IDLE + assert len(status.errors) > 0 + + +class TestSeederThreadSafety: + """Test thread safety of concurrent operations.""" + + def test_concurrent_start_calls_spawn_only_one_thread( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + results = [] + + def try_start(): + results.append(fresh_seeder.start(roots=("models",))) + + threads = [threading.Thread(target=try_start) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + barrier.set() + + assert sum(results) == 1 + + def test_get_status_safe_during_scan( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + statuses = [] + for _ in range(100): + statuses.append(fresh_seeder.get_status()) + + barrier.set() + + assert all( + s.state in (State.RUNNING, State.IDLE, State.CANCELLING) + for s in statuses + ) + + +class TestSeederMarkMissing: + """Test mark_missing_outside_prefixes behavior.""" + + def test_mark_missing_when_idle(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch( + "app.assets.seeder.get_all_known_prefixes", + return_value=["/models", "/input", "/output"], + ), + patch( + "app.assets.seeder.mark_missing_outside_prefixes_safely", return_value=5 + ) as mock_mark, + ): + result = fresh_seeder.mark_missing_outside_prefixes() + assert result == 5 + mock_mark.assert_called_once_with(["/models", "/input", "/output"]) + + def test_mark_missing_raises_when_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + with pytest.raises(ScanInProgressError): + fresh_seeder.mark_missing_outside_prefixes() + + barrier.set() + + def test_mark_missing_returns_zero_when_dependencies_unavailable( + self, fresh_seeder: _AssetSeeder + ): + with patch("app.assets.seeder.dependencies_available", return_value=False): + result = fresh_seeder.mark_missing_outside_prefixes() + assert result == 0 + + def test_prune_first_flag_triggers_mark_missing_before_scan( + self, fresh_seeder: _AssetSeeder + ): + call_order = [] + + def track_mark(prefixes): + call_order.append("mark_missing") + return 3 + + def track_sync(root): + call_order.append(f"sync_{root}") + return set() + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]), + patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark), + patch("app.assets.seeder.sync_root_safely", side_effect=track_sync), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",), prune_first=True) + fresh_seeder.wait(timeout=5.0) + + assert call_order[0] == "mark_missing" + assert "sync_models" in call_order + + +class TestSeederPhases: + """Test phased scanning behavior.""" + + def test_start_fast_only_runs_fast_phase(self, fresh_seeder: _AssetSeeder): + """Verify start_fast only runs the fast phase.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start_fast(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 1 + assert len(enrich_called) == 0 + + def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: _AssetSeeder): + """Verify start_enrich only runs the enrich phase.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start_enrich(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 0 + assert len(enrich_called) == 1 + + def test_full_scan_runs_both_phases(self, fresh_seeder: _AssetSeeder): + """Verify full scan runs both fast and enrich phases.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.FULL) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 1 + assert len(enrich_called) == 1 + + +class TestSeederPauseResume: + """Test pause/resume behavior.""" + + def test_pause_transitions_to_paused( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + paused = fresh_seeder.pause() + assert paused is True + assert fresh_seeder.get_status().state == State.PAUSED + + barrier.set() + + def test_pause_when_idle_returns_false(self, fresh_seeder: _AssetSeeder): + paused = fresh_seeder.pause() + assert paused is False + + def test_resume_returns_to_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + fresh_seeder.pause() + assert fresh_seeder.get_status().state == State.PAUSED + + resumed = fresh_seeder.resume() + assert resumed is True + assert fresh_seeder.get_status().state == State.RUNNING + + barrier.set() + + def test_resume_when_not_paused_returns_false( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + resumed = fresh_seeder.resume() + assert resumed is False + + barrier.set() + + def test_cancel_while_paused_works( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached_checkpoint = threading.Event() + + def slow_collect(*args): + reached_checkpoint.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached_checkpoint.wait(timeout=2.0) + + fresh_seeder.pause() + assert fresh_seeder.get_status().state == State.PAUSED + + cancelled = fresh_seeder.cancel() + assert cancelled is True + + barrier.set() + fresh_seeder.wait(timeout=5.0) + assert fresh_seeder.get_status().state == State.IDLE + +class TestSeederStopRestart: + """Test stop and restart behavior.""" + + def test_stop_is_alias_for_cancel( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + stopped = fresh_seeder.stop() + assert stopped is True + assert fresh_seeder.get_status().state == State.CANCELLING + + barrier.set() + + def test_restart_cancels_and_starts_new_scan( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + start_count = 0 + + def slow_collect(*args): + nonlocal start_count + start_count += 1 + if start_count == 1: + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + barrier.set() + restarted = fresh_seeder.restart() + assert restarted is True + + fresh_seeder.wait(timeout=5.0) + assert start_count == 2 + + def test_restart_preserves_previous_params(self, fresh_seeder: _AssetSeeder): + """Verify restart uses previous params when not overridden.""" + collected_roots = [] + + def track_collect(roots): + collected_roots.append(roots) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("input", "output")) + fresh_seeder.wait(timeout=5.0) + + fresh_seeder.restart() + fresh_seeder.wait(timeout=5.0) + + assert len(collected_roots) == 2 + assert collected_roots[0] == ("input", "output") + assert collected_roots[1] == ("input", "output") + + def test_restart_can_override_params(self, fresh_seeder: _AssetSeeder): + """Verify restart can override previous params.""" + collected_roots = [] + + def track_collect(roots): + collected_roots.append(roots) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + fresh_seeder.restart(roots=("input",)) + fresh_seeder.wait(timeout=5.0) + + assert len(collected_roots) == 2 + assert collected_roots[0] == ("models",) + assert collected_roots[1] == ("input",) + + +def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow: + return UnenrichedReferenceRow( + reference_id=ref_id, asset_id=asset_id, + file_path=f"/fake/{ref_id}.bin", enrichment_level=0, + ) + + +class TestEnrichPhaseDefensiveLogic: + """Test skip_ids filtering and consecutive_empty termination.""" + + def test_failed_refs_are_skipped_on_subsequent_batches( + self, fresh_seeder: _AssetSeeder, + ): + """References that fail enrichment are filtered out of future batches.""" + row_a = _make_row("r1") + row_b = _make_row("r2") + call_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 2: + return [row_a, row_b] + return [] + + enriched_refs: list[list[str]] = [] + + def fake_enrich(rows, **kwargs): + ref_ids = [r.reference_id for r in rows] + enriched_refs.append(ref_ids) + # r1 always fails, r2 succeeds + failed = [r.reference_id for r in rows if r.reference_id == "r1"] + enriched = len(rows) - len(failed) + return enriched, failed + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # First batch: both refs attempted + assert "r1" in enriched_refs[0] + assert "r2" in enriched_refs[0] + # Second batch: r1 filtered out + assert "r1" not in enriched_refs[1] + assert "r2" in enriched_refs[1] + + def test_stops_after_consecutive_empty_batches( + self, fresh_seeder: _AssetSeeder, + ): + """Enrich phase terminates after 3 consecutive batches with zero progress.""" + row = _make_row("r1") + batch_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal batch_count + batch_count += 1 + # Always return the same row (simulating a permanently failing ref) + return [row] + + def fake_enrich(rows, **kwargs): + # Always fail — zero enriched, all failed + return 0, [r.reference_id for r in rows] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # Should stop after exactly 3 consecutive empty batches + # Batch 1: returns row, enrich fails → filtered out in batch 2+ + # But get_unenriched keeps returning it, filter removes it → empty → break + # Actually: batch 1 has row, fails. Batch 2 get_unenriched returns [row], + # skip_ids filters it → empty list → breaks via `if not unenriched: break` + # So it terminates in 2 calls to get_unenriched. + assert batch_count == 2 + + def test_consecutive_empty_counter_resets_on_success( + self, fresh_seeder: _AssetSeeder, + ): + """A successful batch resets the consecutive empty counter.""" + call_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 6: + return [_make_row(f"r{call_count}", f"a{call_count}")] + return [] + + def fake_enrich(rows, **kwargs): + ref_id = rows[0].reference_id + # Fail batches 1-2, succeed batch 3, fail batches 4-5, succeed batch 6 + if ref_id in ("r1", "r2", "r4", "r5"): + return 0, [ref_id] + return 1, [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # All 6 batches should run + 1 final call returning empty + assert call_count == 7 + status = fresh_seeder.get_status() + assert status.state == State.IDLE diff --git a/utils/mime_types.py b/utils/mime_types.py new file mode 100644 index 000000000..916e963c5 --- /dev/null +++ b/utils/mime_types.py @@ -0,0 +1,37 @@ +"""Centralized MIME type initialization. + +Call init_mime_types() once at startup to initialize the MIME type database +and register all custom types used across ComfyUI. +""" + +import mimetypes + +_initialized = False + + +def init_mime_types(): + """Initialize the MIME type database and register all custom types. + + Safe to call multiple times; only runs once. + """ + global _initialized + if _initialized: + return + _initialized = True + + mimetypes.init() + + # Web types (used by server.py for static file serving) + mimetypes.add_type('application/javascript; charset=utf-8', '.js') + mimetypes.add_type('image/webp', '.webp') + + # Model and data file types (used by asset scanning / metadata extraction) + mimetypes.add_type("application/safetensors", ".safetensors") + mimetypes.add_type("application/safetensors", ".sft") + mimetypes.add_type("application/pytorch", ".pt") + mimetypes.add_type("application/pytorch", ".pth") + mimetypes.add_type("application/pickle", ".ckpt") + mimetypes.add_type("application/pickle", ".pkl") + mimetypes.add_type("application/gguf", ".gguf") + mimetypes.add_type("application/yaml", ".yaml") + mimetypes.add_type("application/yaml", ".yml") From 7723f20bbe010a3ea4eac602f77b0ff496f123c4 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 8 Mar 2026 13:17:40 -0700 Subject: [PATCH 58/75] comfy-aimdo 0.2.9 (#12840) Comfy-aimdo 0.2.9 fixes a context issue where if a non-main thread does a spurious garbage collection, cudaFrees are attempted with bad context. Some new APIs for displaying aimdo stats in UI widgets are also added. These are purely additive getters that dont touch cuda APIs. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9527135ec..b1db1cf24 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy filelock av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.7 +comfy-aimdo>=0.2.9 requests simpleeval>=1.0.0 blake3 From e4b0bb8305a4069ef7ff8396bfc6057c736ab95b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 8 Mar 2026 13:25:30 -0700 Subject: [PATCH 59/75] Import assets seeder later, print some package versions. (#12841) --- app/assets/services/hashing.py | 6 +++++- main.py | 8 +++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/app/assets/services/hashing.py b/app/assets/services/hashing.py index 92aee6402..41d8b4615 100644 --- a/app/assets/services/hashing.py +++ b/app/assets/services/hashing.py @@ -3,8 +3,12 @@ import os from contextlib import contextmanager from dataclasses import dataclass from typing import IO, Any, Callable, Iterator +import logging -from blake3 import blake3 +try: + from blake3 import blake3 +except ModuleNotFoundError: + logging.warning("WARNING: blake3 package not installed") DEFAULT_CHUNK = 8 * 1024 * 1024 diff --git a/main.py b/main.py index a8fc1a28d..1977f9362 100644 --- a/main.py +++ b/main.py @@ -3,11 +3,11 @@ comfy.options.enable_args_parsing() import os import importlib.util +import importlib.metadata import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger -from app.assets.seeder import asset_seeder import itertools import utils.extra_config from utils.mime_types import init_mime_types @@ -182,6 +182,7 @@ if 'torch' in sys.modules: import comfy.utils +from app.assets.seeder import asset_seeder import execution import server @@ -451,6 +452,11 @@ if __name__ == "__main__": # Running directly, just start ComfyUI. logging.info("Python version: {}".format(sys.version)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) + for package in ("comfy-aimdo", "comfy-kitchen"): + try: + logging.info("{} version: {}".format(package, importlib.metadata.version(package))) + except: + pass if sys.version_info.major == 3 and sys.version_info.minor < 10: logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") From 06f85e2c792c626f2cab3cb4f94cd30d43e9347b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Mon, 9 Mar 2026 22:08:51 +0200 Subject: [PATCH 60/75] Fix text encoder lora loading for wrapped models (#12852) --- comfy/lora.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index f36ddb046..63ee85323 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}): for k in sdk: if k.endswith(".weight"): key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names + tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models + if tp > 0 and not k.startswith("clip_"): + key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False From 814dab9f4636df22a36cbbad21e35ac7609a0ef2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 10 Mar 2026 10:03:22 +0800 Subject: [PATCH 61/75] Update workflow templates to v0.9.18 (#12857) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b1db1cf24..bb58f8d01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.11 +comfyui-workflow-templates==0.9.18 comfyui-embedded-docs==0.4.3 torch torchsde From 740d998c9cc821ca0a72b5b5d4b17aba1aec6b44 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:49:31 +0900 Subject: [PATCH 62/75] fix(manager): improve install guidance when comfyui-manager is not installed (#12810) --- main.py | 13 ++++++++++--- manager_requirements.txt | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 1977f9362..83a7244db 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ comfy.options.enable_args_parsing() import os import importlib.util +import shutil import importlib.metadata import folder_paths import time @@ -64,8 +65,15 @@ if __name__ == "__main__": def handle_comfyui_manager_unavailable(): - if not args.windows_standalone_build: - logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") + manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt") + uv_available = shutil.which("uv") is not None + + pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}" + msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}" + if uv_available: + msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}" + msg += "\n" + logging.warning(msg) args.enable_manager = False @@ -173,7 +181,6 @@ execute_prestartup_script() # Main code import asyncio -import shutil import threading import gc diff --git a/manager_requirements.txt b/manager_requirements.txt index c420cc48e..6bcc3fb50 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1b1 +comfyui_manager==4.1b2 \ No newline at end of file From c4fb0271cd7fbddb2381372b1f7c1206d1dd58fc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 9 Mar 2026 20:37:58 -0700 Subject: [PATCH 63/75] Add a way for nodes to add pre attn patches to flux model. (#12861) --- comfy/ldm/flux/layers.py | 15 ++++++++++++++- comfy/ldm/flux/math.py | 2 ++ comfy/ldm/flux/model.py | 2 +- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 8b3f500d7..e20d498f8 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module): del txt_k, img_k v = torch.cat((txt_v, img_v), dim=2) del txt_v, img_v + + extra_options["img_slice"] = [txt.shape[1], q.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # run actual attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v if "attn1_output_patch" in transformer_patches: - extra_options["img_slice"] = [txt.shape[1], attn.shape[1]] patch = transformer_patches["attn1_output_patch"] for p in patch: attn = p(attn, extra_options) @@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module): del qkv q, k = self.norm(q, k, v) + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # compute attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 5e764bb46..824daf5e6 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def _apply_rope1(x: Tensor, freqs_cis: Tensor): x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]: + freqs_cis = freqs_cis[:, :, :x_.shape[2]] x_out = freqs_cis[..., 0] * x_[..., 0] x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index ef4dcf7c5..00f12c031 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -170,7 +170,7 @@ class Flux(nn.Module): if "post_input" in patches: for p in patches["post_input"]: - out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) + out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) img = out["img"] txt = out["txt"] img_ids = out["img_ids"] From a912809c252f5a2d69c8ab4035fc262a578fdcee Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 9 Mar 2026 20:50:10 -0700 Subject: [PATCH 64/75] model_detection: deep clone pre edited edited weights (#12862) Deep clone these weights as needed to avoid segfaulting when it tries to touch the original mmap. --- comfy/model_detection.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 6eace4628..35a6822e3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,4 +1,5 @@ import json +import comfy.memory_management import comfy.supported_models import comfy.supported_models_base import comfy.utils @@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): new[:old_weight.shape[0]] = old_weight old_weight = new + if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled: + old_weight = old_weight.clone() + w = old_weight.narrow(offset[0], offset[1], offset[2]) else: + if comfy.memory_management.aimdo_enabled: + weight = weight.clone() old_weight = weight w = weight w[:] = fun(weight) From 535c16ce6e3d2634d6eb2fd17ecccb8d497e26a0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 9 Mar 2026 21:41:02 -0700 Subject: [PATCH 65/75] Widen OOM_EXCEPTION to AcceleratorError form (#12835) Pytorch only filters for OOMs in its own allocators however there are paths that can OOM on allocators made outside the pytorch allocators. These manifest as an AllocatorError as pytorch does not have universal error translation to its OOM type on exception. Handle it. A log I have for this also shows a double report of the error async, so call the async discarder to cleanup and make these OOMs look like OOMs. --- comfy/ldm/modules/attention.py | 3 ++- comfy/ldm/modules/diffusionmodules/model.py | 6 ++++-- comfy/ldm/modules/sub_quadratic_attention.py | 3 ++- comfy/model_management.py | 12 ++++++++++++ comfy/sd.py | 6 ++++-- comfy_extras/nodes_upscale_model.py | 3 ++- execution.py | 2 +- 7 files changed, 27 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 10d051325..b193fe5e8 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) if first_op_done == False: model_management.soft_empty_cache(True) if cleared_cache == False: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 805592aa5..fcbaa074f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -258,7 +258,8 @@ def slice_attention(q, k, v): r1[:, :, i:end] = torch.bmm(v, s2) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) model_management.soft_empty_cache(True) steps *= 2 if steps > 128: @@ -314,7 +315,8 @@ def pytorch_attention(q, k, v): try: out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") oom_fallback = True if oom_fallback: diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index fab145f1c..f982afc2b 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking( try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined torch.exp(attn_scores, out=attn_scores) diff --git a/comfy/model_management.py b/comfy/model_management.py index 07bc8ad67..81550c790 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -270,6 +270,18 @@ try: except: OOM_EXCEPTION = Exception +def is_oom(e): + if isinstance(e, OOM_EXCEPTION): + return True + if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2: + discard_cuda_async_error() + return True + return False + +def raise_non_oom(e): + if not is_oom(e): + raise e + XFORMERS_VERSION = "" XFORMERS_ENABLED_VAE = True if args.disable_xformers: diff --git a/comfy/sd.py b/comfy/sd.py index 888ef1e77..adcd67767 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -954,7 +954,8 @@ class VAE: if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples[x:x+batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. @@ -1029,7 +1030,8 @@ class VAE: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) samples[x:x + batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 97b9e948d..db4f9d231 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode): pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) oom = False - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) tile //= 2 if tile < 128: raise e diff --git a/execution.py b/execution.py index 7ccdbf93e..a7791efed 100644 --- a/execution.py +++ b/execution.py @@ -612,7 +612,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, logging.error(traceback.format_exc()) tips = "" - if isinstance(ex, comfy.model_management.OOM_EXCEPTION): + if comfy.model_management.is_oom(ex): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") From 8086468d2a1a5a6ed70fea3391e7fb9248ebc7da Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 10 Mar 2026 09:05:31 -0700 Subject: [PATCH 66/75] main: switch on faulthandler (#12868) When we get segfault bug reports we dont get much. Switch on pythons inbuilt tracer for segfault. --- main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main.py b/main.py index 83a7244db..8905fd09a 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,7 @@ from app.logger import setup_logger import itertools import utils.extra_config from utils.mime_types import init_mime_types +import faulthandler import logging import sys from comfy_execution.progress import get_progress_state @@ -26,6 +27,8 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +faulthandler.enable(file=sys.stderr, all_threads=False) + import comfy_aimdo.control if enables_dynamic_vram(): From 3ad36d6be66b2af2a7c3dc9ab6936eebc6b98075 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 10 Mar 2026 17:09:12 -0700 Subject: [PATCH 67/75] Allow model patches to have a cleanup function. (#12878) The function gets called after sampling is finished. --- comfy/model_patcher.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 745384271..bc3a8f446 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -599,6 +599,27 @@ class ModelPatcher: return models + def model_patches_call_function(self, function_name="cleanup", arguments={}): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], function_name): + getattr(patch_list[i], function_name)(**arguments) + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], function_name): + getattr(patch_list[k], function_name)(**arguments) + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] + if hasattr(wrap_func, function_name): + getattr(wrap_func, function_name)(**arguments) + def model_dtype(self): if hasattr(self.model, "get_dtype"): return self.model.get_dtype() @@ -1062,6 +1083,7 @@ class ModelPatcher: return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) def cleanup(self): + self.model_patches_call_function(function_name="cleanup") self.clean_hooks() if hasattr(self.model, "current_patcher"): self.model.current_patcher = None From 9642e4407b60b291744cc1d34501783cff6702e5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 10 Mar 2026 21:09:35 -0700 Subject: [PATCH 68/75] Add pre attention and post input patches to qwen image model. (#12879) --- comfy/ldm/qwen_image/model.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 6eb744286..0862f72f7 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -149,6 +149,9 @@ class Attention(nn.Module): seq_img = hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1] + transformer_patches = transformer_options.get("patches", {}) + extra_options = transformer_options.copy() + # Project and reshape to BHND format (batch, heads, seq, dim) img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() @@ -167,15 +170,22 @@ class Attention(nn.Module): joint_key = torch.cat([txt_key, img_key], dim=2) joint_value = torch.cat([txt_value, img_value], dim=2) - joint_query = apply_rope1(joint_query, image_rotary_emb) - joint_key = apply_rope1(joint_key, image_rotary_emb) - if encoder_hidden_states_mask is not None: attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask else: attn_mask = None + extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options) + joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask) + + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attn_mask, transformer_options=transformer_options, skip_reshape=True) @@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module): timestep_zero_index = None if ref_latents is not None: + ref_num_tokens = [] h = 0 w = 0 index = 0 @@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module): kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + ref_num_tokens.append(kontext.shape[1]) if timestep_zero: if index > 0: timestep = torch.cat([timestep, timestep * 0], dim=0) timestep_zero_index = num_embeds + transformer_options = transformer_options.copy() + transformer_options["reference_image_num_tokens"] = ref_num_tokens txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) - ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() - del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states) @@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) + if "post_input" in patches: + for p in patches["post_input"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + img_ids = out["img_ids"] + txt_ids = out["txt_ids"] + + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() + del ids, txt_ids, img_ids + transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.transformer_blocks): From 980621da83267beffcb84839a27101b7092256e7 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 11 Mar 2026 08:49:38 -0700 Subject: [PATCH 69/75] comfy-aimdo 0.2.10 (#12890) Comfy Aimdo 0.2.10 fixes the aimdo allocator hook for legacy cudaMalloc consumers. Some consumers of cudaMalloc assume implicit synchronization built in closed source logic inside cuda. This is preserved by passing through to cuda as-is and accouting after the fact as opposed to integrating these hooks with Aimdos VMA based allocator. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bb58f8d01..89cd994e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy filelock av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.9 +comfy-aimdo>=0.2.10 requests simpleeval>=1.0.0 blake3 From 3365008dfe5a7a46cbe76d8ad0d7efb054617733 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 11 Mar 2026 18:53:55 +0200 Subject: [PATCH 70/75] feat(api-nodes): add Reve Image nodes (#12848) --- comfy_api_nodes/apis/reve.py | 68 ++++++ comfy_api_nodes/nodes_reve.py | 395 +++++++++++++++++++++++++++++++++ comfy_api_nodes/util/client.py | 12 +- 3 files changed, 474 insertions(+), 1 deletion(-) create mode 100644 comfy_api_nodes/apis/reve.py create mode 100644 comfy_api_nodes/nodes_reve.py diff --git a/comfy_api_nodes/apis/reve.py b/comfy_api_nodes/apis/reve.py new file mode 100644 index 000000000..c6b5a69d8 --- /dev/null +++ b/comfy_api_nodes/apis/reve.py @@ -0,0 +1,68 @@ +from pydantic import BaseModel, Field + + +class RevePostprocessingOperation(BaseModel): + process: str = Field(..., description="The postprocessing operation: upscale or remove_background.") + upscale_factor: int | None = Field( + None, + description="Upscale factor (2, 3, or 4). Only used when process is upscale.", + ge=2, + le=4, + ) + + +class ReveImageCreateRequest(BaseModel): + prompt: str = Field(...) + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageEditRequest(BaseModel): + edit_instruction: str = Field(...) + reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.") + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int | None = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageRemixRequest(BaseModel): + prompt: str = Field(...) + reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.") + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int | None = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageResponse(BaseModel): + image: str | None = Field(None, description="The base64 encoded image data.") + request_id: str | None = Field(None, description="A unique id for the request.") + credits_used: float | None = Field(None, description="The number of credits used for this request.") + version: str | None = Field(None, description="The specific model version used.") + content_violation: bool | None = Field( + None, description="Indicates whether the generated image violates the content policy." + ) diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py new file mode 100644 index 000000000..608d9f058 --- /dev/null +++ b/comfy_api_nodes/nodes_reve.py @@ -0,0 +1,395 @@ +from io import BytesIO + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.reve import ( + ReveImageCreateRequest, + ReveImageEditRequest, + ReveImageRemixRequest, + RevePostprocessingOperation, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + bytesio_to_image_tensor, + sync_op_raw, + tensor_to_base64_string, + validate_string, +) + + +def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None: + ops = [] + if upscale["upscale"] == "enabled": + ops.append( + RevePostprocessingOperation( + process="upscale", + upscale_factor=upscale["upscale_factor"], + ) + ) + if remove_background: + ops.append(RevePostprocessingOperation(process="remove_background")) + return ops or None + + +def _postprocessing_inputs(): + return [ + IO.DynamicCombo.Input( + "upscale", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Int.Input( + "upscale_factor", + default=2, + min=2, + max=4, + step=1, + tooltip="Upscale factor (2x, 3x, or 4x).", + ), + ], + ), + ], + tooltip="Upscale the generated image. May add additional cost.", + ), + IO.Boolean.Input( + "remove_background", + default=False, + tooltip="Remove the background from the generated image. May add additional cost.", + ), + ] + + +def _reve_price_extractor(headers: dict) -> float | None: + credits_used = headers.get("x-reve-credits-used") + if credits_used is not None: + return float(credits_used) / 524.48 + return None + + +def _reve_response_header_validator(headers: dict) -> None: + error_code = headers.get("x-reve-error-code") + if error_code: + raise ValueError(f"Reve API error: {error_code}") + if headers.get("x-reve-content-violation", "").lower() == "true": + raise ValueError("The generated image was flagged for content policy violation.") + + +def _model_inputs(versions: list[str], aspect_ratios: list[str]): + return [ + IO.DynamicCombo.Option( + version, + [ + IO.Combo.Input( + "aspect_ratio", + options=aspect_ratios, + tooltip="Aspect ratio of the output image.", + ), + IO.Int.Input( + "test_time_scaling", + default=1, + min=1, + max=5, + step=1, + tooltip="Higher values produce better images but cost more credits.", + advanced=True, + ), + ], + ) + for version in versions + ] + + +class ReveImageCreateNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageCreateNode", + display_name="Reve Image Create", + category="api node/image/Reve", + description="Generate images from text descriptions using Reve.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. Maximum 2560 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-create@20250915"], + aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for generation.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""", + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2560) + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/create", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageCreateRequest( + prompt=prompt, + aspect_ratio=model["aspect_ratio"], + version=model["model"], + test_time_scaling=model["test_time_scaling"], + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageEditNode", + display_name="Reve Image Edit", + category="api node/image/Reve", + description="Edit images using natural language instructions with Reve.", + inputs=[ + IO.Image.Input("image", tooltip="The image to edit."), + IO.String.Input( + "edit_instruction", + multiline=True, + default="", + tooltip="Text description of how to edit the image. Maximum 2560 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-edit@20250915", "reve-edit-fast@20251030"], + aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for editing.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model"], + ), + expr=""" + ( + $isFast := $contains(widgets.model, "fast"); + $base := $isFast ? 0.01001 : 0.0572; + {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + edit_instruction: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(edit_instruction, min_length=1, max_length=2560) + tts = model["test_time_scaling"] + ar = model["aspect_ratio"] + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/edit", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageEditRequest( + edit_instruction=edit_instruction, + reference_image=tensor_to_base64_string(image), + aspect_ratio=ar if ar != "auto" else None, + version=model["model"], + test_time_scaling=tts if tts and tts > 1 else None, + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveImageRemixNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageRemixNode", + display_name="Reve Image Remix", + category="api node/image/Reve", + description="Combine reference images with text prompts to create new images using Reve.", + inputs=[ + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="image_", + min=1, + max=6, + ), + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. " + "May include XML img tags to reference specific images by index, " + "e.g. 0, 1, etc.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-remix@20250915", "reve-remix-fast@20251030"], + aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for remixing.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model"], + ), + expr=""" + ( + $isFast := $contains(widgets.model, "fast"); + $base := $isFast ? 0.01001 : 0.0572; + {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + reference_images: IO.Autogrow.Type, + prompt: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2560) + if not reference_images: + raise ValueError("At least one reference image is required.") + ref_base64_list = [] + for key in reference_images: + ref_base64_list.append(tensor_to_base64_string(reference_images[key])) + if len(ref_base64_list) > 6: + raise ValueError("Maximum 6 reference images are allowed.") + tts = model["test_time_scaling"] + ar = model["aspect_ratio"] + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/remix", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageRemixRequest( + prompt=prompt, + reference_images=ref_base64_list, + aspect_ratio=ar if ar != "auto" else None, + version=model["model"], + test_time_scaling=tts if tts and tts > 1 else None, + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ReveImageCreateNode, + ReveImageEditNode, + ReveImageRemixNode, + ] + + +async def comfy_entrypoint() -> ReveExtension: + return ReveExtension() diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 79ffb77c1..9d730b81a 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -67,6 +67,7 @@ class _RequestConfig: progress_origin_ts: float | None = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None is_rate_limited: Callable[[int, Any], bool] | None = None + response_header_validator: Callable[[dict[str, str]], None] | None = None @dataclass @@ -202,11 +203,13 @@ async def sync_op_raw( monitor_progress: bool = True, max_retries_on_rate_limit: int = 16, is_rate_limited: Callable[[int, Any], bool] | None = None, + response_header_validator: Callable[[dict[str, str]], None] | None = None, ) -> dict[str, Any] | bytes: """ Make a single network request. - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). - If as_binary=True: returns bytes. + - response_header_validator: optional callback receiving response headers dict """ if isinstance(data, BaseModel): data = data.model_dump(exclude_none=True) @@ -232,6 +235,7 @@ async def sync_op_raw( price_extractor=price_extractor, max_retries_on_rate_limit=max_retries_on_rate_limit, is_rate_limited=is_rate_limited, + response_header_validator=response_header_validator, ) return await _request_base(cfg, expect_binary=as_binary) @@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total ) bytes_payload = bytes(buff) + resp_headers = {k.lower(): v for k, v in resp.headers.items()} + if cfg.price_extractor: + with contextlib.suppress(Exception): + extracted_price = cfg.price_extractor(resp_headers) + if cfg.response_header_validator: + cfg.response_header_validator(resp_headers) operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) request_logger.log_request_response( @@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): request_method=method, request_url=url, response_status_code=resp.status, - response_headers=dict(resp.headers), + response_headers=resp_headers, response_content=bytes_payload, ) return bytes_payload From 4f4f8659c205069f74da8ac47378a5b1c0e142ca Mon Sep 17 00:00:00 2001 From: Adi Borochov <58855640+adiborochov@users.noreply.github.com> Date: Wed, 11 Mar 2026 19:04:13 +0200 Subject: [PATCH 71/75] fix: guard torch.AcceleratorError for compatibility with torch < 2.8.0 (#12874) * fix: guard torch.AcceleratorError for compatibility with torch < 2.8.0 torch.AcceleratorError was introduced in PyTorch 2.8.0. Accessing it directly raises AttributeError on older versions. Use a try/except fallback at module load time, consistent with the existing pattern used for OOM_EXCEPTION. * fix: address review feedback for AcceleratorError compat - Fall back to RuntimeError instead of type(None) for ACCELERATOR_ERROR, consistent with OOM_EXCEPTION fallback pattern and valid for except clauses - Add "out of memory" message introspection for RuntimeError fallback case - Use RuntimeError directly in discard_cuda_async_error except clause --------- --- comfy/model_management.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 81550c790..81c89b180 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -270,10 +270,15 @@ try: except: OOM_EXCEPTION = Exception +try: + ACCELERATOR_ERROR = torch.AcceleratorError +except AttributeError: + ACCELERATOR_ERROR = RuntimeError + def is_oom(e): if isinstance(e, OOM_EXCEPTION): return True - if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2: + if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()): discard_cuda_async_error() return True return False @@ -1275,7 +1280,7 @@ def discard_cuda_async_error(): b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) _ = a + b synchronize() - except torch.AcceleratorError: + except RuntimeError: #Dump it! We already know about it from the synchronous return pass From f6274c06b4e7bce8adbc1c60ae5a4c168825a614 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 11 Mar 2026 13:37:31 -0700 Subject: [PATCH 72/75] Fix issue with batch_size > 1 on some models. (#12892) --- comfy/ldm/flux/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index e20d498f8..e28d704b4 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): return tensor * m_mult else: for d in modulation_dims: - tensor[:, d[0]:d[1]] *= m_mult[:, d[2]] + tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1] if m_add is not None: - tensor[:, d[0]:d[1]] += m_add[:, d[2]] + tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1] return tensor From abc87d36693b007bdbdab5ee753ccea6326acb34 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Thu, 12 Mar 2026 06:04:51 +0900 Subject: [PATCH 73/75] Bump comfyui-frontend-package to 1.41.15 (#12891) --------- Co-authored-by: Alexander Brown --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 89cd994e9..ffa5fa376 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.39.19 +comfyui-frontend-package==1.41.15 comfyui-workflow-templates==0.9.18 comfyui-embedded-docs==0.4.3 torch From 9ce4c3dd87c9c77dfe0371045fa920ce55e08973 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Thu, 12 Mar 2026 10:16:30 +0900 Subject: [PATCH 74/75] Bump comfyui-frontend-package to 1.41.16 (#12894) Co-authored-by: github-actions[bot] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ffa5fa376..2272d121a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.41.15 +comfyui-frontend-package==1.41.16 comfyui-workflow-templates==0.9.18 comfyui-embedded-docs==0.4.3 torch From c5e7b9cdaf9b04aba65f6282f4c953748b9a77b1 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Thu, 12 Mar 2026 01:13:43 -0500 Subject: [PATCH 75/75] feat(isolation): process isolation for custom nodes via pyisolate Adds opt-in process isolation for custom nodes using pyisolate's bwrap sandbox and JSON-RPC bridge. Each isolated node pack runs in its own child process with zero-copy tensor transfer via shared memory. Core infrastructure: - CLI flag --use-process-isolation to enable isolation - Host/child startup fencing via PYISOLATE_CHILD env var - Manifest-driven node discovery and extension loading - JSON-RPC bridge between host and child processes - Shared memory forensics for leak detection Proxy layer: - ModelPatcher, CLIP, VAE, and ModelSampling proxies - Host service proxies (folder_paths, model_management, progress, etc.) - Proxy base with automatic method forwarding Execution integration: - Extension wrapper with V3 hidden param mapping - Runtime helpers for isolated node execution - Host policy for node isolation decisions - Fenced sampler device handling and model ejection parity Serializers for cross-process data transfer: - File3D (GLB), PLY (structured + gaussian), NPZ (streaming frames), VIDEO (VideoFromFile + VideoFromComponents) serializers - data_type flag in SerializerRegistry for type-aware dispatch - Isolated get_temp_directory() fence New core save nodes: - SavePLY and SaveNPZ with comfytype registrations (Ply, Npz) DynamicVRAM compatibility: - comfy-aimdo early init gated by isolation fence Tests: - Integration and policy tests for isolation lifecycle - Manifest loader, host policy, proxy, and adapter unit tests Depends on: pyisolate >= 0.9.2 --- .gitignore | 1 + comfy/cli_args.py | 2 + comfy/hooks.py | 36 +- comfy/isolation/__init__.py | 394 ++++++ comfy/isolation/adapter.py | 641 +++++++++ comfy/isolation/child_hooks.py | 141 ++ comfy/isolation/clip_proxy.py | 327 +++++ comfy/isolation/extension_loader.py | 341 +++++ comfy/isolation/extension_wrapper.py | 680 +++++++++ comfy/isolation/host_hooks.py | 26 + comfy/isolation/host_policy.py | 83 ++ comfy/isolation/manifest_loader.py | 186 +++ comfy/isolation/model_patcher_proxy.py | 861 ++++++++++++ .../isolation/model_patcher_proxy_registry.py | 1230 +++++++++++++++++ comfy/isolation/model_patcher_proxy_utils.py | 156 +++ comfy/isolation/model_sampling_proxy.py | 360 +++++ comfy/isolation/proxies/__init__.py | 17 + comfy/isolation/proxies/base.py | 283 ++++ comfy/isolation/proxies/folder_paths_proxy.py | 29 + comfy/isolation/proxies/helper_proxies.py | 98 ++ .../proxies/model_management_proxy.py | 27 + comfy/isolation/proxies/progress_proxy.py | 35 + comfy/isolation/proxies/prompt_server_impl.py | 265 ++++ comfy/isolation/proxies/utils_proxy.py | 64 + comfy/isolation/rpc_bridge.py | 49 + comfy/isolation/runtime_helpers.py | 343 +++++ comfy/isolation/shm_forensics.py | 217 +++ comfy/isolation/vae_proxy.py | 214 +++ comfy/k_diffusion/sampling.py | 10 +- comfy/model_base.py | 15 +- comfy/model_management.py | 115 +- comfy/samplers.py | 64 +- comfy_api/latest/_io.py | 14 +- comfy_api/latest/_util/__init__.py | 4 + comfy_api/latest/_util/npz_types.py | 27 + comfy_api/latest/_util/ply_types.py | 97 ++ comfy_extras/nodes_save_npz.py | 40 + comfy_extras/nodes_save_ply.py | 34 + cuda_malloc.py | 2 +- execution.py | 145 +- main.py | 107 +- nodes.py | 44 +- pyproject.toml | 19 + requirements.txt | 2 + server.py | 3 +- tests/isolation/test_client_snapshot.py | 122 ++ .../test_cuda_wheels_and_env_flags.py | 302 ++++ tests/isolation/test_folder_paths_proxy.py | 111 ++ tests/isolation/test_host_policy.py | 72 + tests/isolation/test_init.py | 56 + tests/isolation/test_manifest_loader_cache.py | 434 ++++++ .../isolation/test_model_management_proxy.py | 50 + tests/isolation/test_path_helpers.py | 93 ++ tests/test_adapter.py | 51 + 54 files changed, 9061 insertions(+), 78 deletions(-) create mode 100644 comfy/isolation/__init__.py create mode 100644 comfy/isolation/adapter.py create mode 100644 comfy/isolation/child_hooks.py create mode 100644 comfy/isolation/clip_proxy.py create mode 100644 comfy/isolation/extension_loader.py create mode 100644 comfy/isolation/extension_wrapper.py create mode 100644 comfy/isolation/host_hooks.py create mode 100644 comfy/isolation/host_policy.py create mode 100644 comfy/isolation/manifest_loader.py create mode 100644 comfy/isolation/model_patcher_proxy.py create mode 100644 comfy/isolation/model_patcher_proxy_registry.py create mode 100644 comfy/isolation/model_patcher_proxy_utils.py create mode 100644 comfy/isolation/model_sampling_proxy.py create mode 100644 comfy/isolation/proxies/__init__.py create mode 100644 comfy/isolation/proxies/base.py create mode 100644 comfy/isolation/proxies/folder_paths_proxy.py create mode 100644 comfy/isolation/proxies/helper_proxies.py create mode 100644 comfy/isolation/proxies/model_management_proxy.py create mode 100644 comfy/isolation/proxies/progress_proxy.py create mode 100644 comfy/isolation/proxies/prompt_server_impl.py create mode 100644 comfy/isolation/proxies/utils_proxy.py create mode 100644 comfy/isolation/rpc_bridge.py create mode 100644 comfy/isolation/runtime_helpers.py create mode 100644 comfy/isolation/shm_forensics.py create mode 100644 comfy/isolation/vae_proxy.py create mode 100644 comfy_api/latest/_util/npz_types.py create mode 100644 comfy_api/latest/_util/ply_types.py create mode 100644 comfy_extras/nodes_save_npz.py create mode 100644 comfy_extras/nodes_save_ply.py create mode 100644 tests/isolation/test_client_snapshot.py create mode 100644 tests/isolation/test_cuda_wheels_and_env_flags.py create mode 100644 tests/isolation/test_folder_paths_proxy.py create mode 100644 tests/isolation/test_host_policy.py create mode 100644 tests/isolation/test_init.py create mode 100644 tests/isolation/test_manifest_loader_cache.py create mode 100644 tests/isolation/test_model_management_proxy.py create mode 100644 tests/isolation/test_path_helpers.py create mode 100644 tests/test_adapter.py diff --git a/.gitignore b/.gitignore index 2700ad5c2..f893b5f14 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ web_custom_versions/ openapi.yaml filtered-openapi.yaml uv.lock +.pyisolate_venvs/ diff --git a/comfy/cli_args.py b/comfy/cli_args.py index e9832acaf..d09736042 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -179,6 +179,8 @@ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable lo parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") +parser.add_argument("--use-process-isolation", action="store_true", help="Enable process isolation for custom nodes with pyisolate.yaml manifests.") + parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level') parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).") diff --git a/comfy/hooks.py b/comfy/hooks.py index 1a76c7ba4..7a5f69ca7 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -14,6 +14,9 @@ if TYPE_CHECKING: import comfy.lora import comfy.model_management import comfy.patcher_extension +from comfy.cli_args import args +import uuid +import os from node_helpers import conditioning_set_values # ####################################################################################################### @@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum): HookedOnly = "hooked_only" +_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + + class _HookRef: - pass + def __init__(self): + if _ISOLATION_HOOKREF_MODE: + self._pyisolate_id = str(uuid.uuid4()) + + def _ensure_pyisolate_id(self): + pyisolate_id = getattr(self, "_pyisolate_id", None) + if pyisolate_id is None: + pyisolate_id = str(uuid.uuid4()) + self._pyisolate_id = pyisolate_id + return pyisolate_id + + def __eq__(self, other): + if not _ISOLATION_HOOKREF_MODE: + return self is other + if not isinstance(other, _HookRef): + return False + return self._ensure_pyisolate_id() == other._ensure_pyisolate_id() + + def __hash__(self): + if not _ISOLATION_HOOKREF_MODE: + return id(self) + return hash(self._ensure_pyisolate_id()) + + def __str__(self): + if not _ISOLATION_HOOKREF_MODE: + return super().__str__() + return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}" def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -168,6 +200,8 @@ class WeightHook(Hook): key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) else: key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if self.weights is None: + self.weights = {} weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) else: if target == EnumWeightTarget.Clip: diff --git a/comfy/isolation/__init__.py b/comfy/isolation/__init__.py new file mode 100644 index 000000000..34ccb34dc --- /dev/null +++ b/comfy/isolation/__init__.py @@ -0,0 +1,394 @@ +# pylint: disable=consider-using-from-import,cyclic-import,global-statement,global-variable-not-assigned,import-outside-toplevel,logging-fstring-interpolation +from __future__ import annotations +import asyncio +import inspect +import logging +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Set, TYPE_CHECKING +import folder_paths +from .extension_loader import load_isolated_node +from .manifest_loader import find_manifest_directories +from .runtime_helpers import build_stub_class, get_class_types_for_extension +from .shm_forensics import scan_shm_forensics, start_shm_forensics + +if TYPE_CHECKING: + from pyisolate import ExtensionManager + from .extension_wrapper import ComfyNodeExtension + +LOG_PREFIX = "][" +isolated_node_timings: List[tuple[float, Path, int]] = [] + +PYISOLATE_VENV_ROOT = Path(folder_paths.base_path) / ".pyisolate_venvs" +PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True) + +logger = logging.getLogger(__name__) +_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 +_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000 + + +def initialize_proxies() -> None: + from .child_hooks import is_child_process + + is_child = is_child_process() + + if is_child: + from .child_hooks import initialize_child_process + + initialize_child_process() + else: + from .host_hooks import initialize_host_process + + initialize_host_process() + start_shm_forensics() + + +@dataclass(frozen=True) +class IsolatedNodeSpec: + node_name: str + display_name: str + stub_class: type + module_path: Path + + +_ISOLATED_NODE_SPECS: List[IsolatedNodeSpec] = [] +_CLAIMED_PATHS: Set[Path] = set() +_ISOLATION_SCAN_ATTEMPTED = False +_EXTENSION_MANAGERS: List["ExtensionManager"] = [] +_RUNNING_EXTENSIONS: Dict[str, "ComfyNodeExtension"] = {} +_ISOLATION_BACKGROUND_TASK: Optional["asyncio.Task[List[IsolatedNodeSpec]]"] = None +_EARLY_START_TIME: Optional[float] = None + + +def start_isolation_loading_early(loop: "asyncio.AbstractEventLoop") -> None: + global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME + if _ISOLATION_BACKGROUND_TASK is not None: + return + _EARLY_START_TIME = time.perf_counter() + _ISOLATION_BACKGROUND_TASK = loop.create_task(initialize_isolation_nodes()) + + +async def await_isolation_loading() -> List[IsolatedNodeSpec]: + global _ISOLATION_BACKGROUND_TASK, _EARLY_START_TIME + if _ISOLATION_BACKGROUND_TASK is not None: + specs = await _ISOLATION_BACKGROUND_TASK + return specs + return await initialize_isolation_nodes() + + +async def initialize_isolation_nodes() -> List[IsolatedNodeSpec]: + global _ISOLATED_NODE_SPECS, _ISOLATION_SCAN_ATTEMPTED, _CLAIMED_PATHS + + if _ISOLATED_NODE_SPECS: + return _ISOLATED_NODE_SPECS + + if _ISOLATION_SCAN_ATTEMPTED: + return [] + + _ISOLATION_SCAN_ATTEMPTED = True + manifest_entries = find_manifest_directories() + _CLAIMED_PATHS = {entry[0].resolve() for entry in manifest_entries} + + if not manifest_entries: + return [] + + os.environ["PYISOLATE_ISOLATION_ACTIVE"] = "1" + concurrency_limit = max(1, (os.cpu_count() or 4) // 2) + semaphore = asyncio.Semaphore(concurrency_limit) + + async def load_with_semaphore( + node_dir: Path, manifest: Path + ) -> List[IsolatedNodeSpec]: + async with semaphore: + load_start = time.perf_counter() + spec_list = await load_isolated_node( + node_dir, + manifest, + logger, + lambda name, info, extension: build_stub_class( + name, + info, + extension, + _RUNNING_EXTENSIONS, + logger, + ), + PYISOLATE_VENV_ROOT, + _EXTENSION_MANAGERS, + ) + spec_list = [ + IsolatedNodeSpec( + node_name=node_name, + display_name=display_name, + stub_class=stub_cls, + module_path=node_dir, + ) + for node_name, display_name, stub_cls in spec_list + ] + isolated_node_timings.append( + (time.perf_counter() - load_start, node_dir, len(spec_list)) + ) + return spec_list + + tasks = [ + load_with_semaphore(node_dir, manifest) + for node_dir, manifest in manifest_entries + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + specs: List[IsolatedNodeSpec] = [] + for result in results: + if isinstance(result, Exception): + logger.error( + "%s Isolated node failed during startup; continuing: %s", + LOG_PREFIX, + result, + ) + continue + specs.extend(result) + + _ISOLATED_NODE_SPECS = specs + return list(_ISOLATED_NODE_SPECS) + + +def _get_class_types_for_extension(extension_name: str) -> Set[str]: + """Get all node class types (node names) belonging to an extension.""" + extension = _RUNNING_EXTENSIONS.get(extension_name) + if not extension: + return set() + + ext_path = Path(extension.module_path) + class_types = set() + for spec in _ISOLATED_NODE_SPECS: + if spec.module_path.resolve() == ext_path.resolve(): + class_types.add(spec.node_name) + + return class_types + + +async def notify_execution_graph(needed_class_types: Set[str]) -> None: + """Evict running extensions not needed for current execution.""" + await wait_for_model_patcher_quiescence( + timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS, + fail_loud=True, + marker="ISO:notify_graph_wait_idle", + ) + + async def _stop_extension( + ext_name: str, extension: "ComfyNodeExtension", reason: str + ) -> None: + logger.info("%s ISO:eject_start ext=%s reason=%s", LOG_PREFIX, ext_name, reason) + logger.debug("%s ISO:stop_start ext=%s", LOG_PREFIX, ext_name) + stop_result = extension.stop() + if inspect.isawaitable(stop_result): + await stop_result + _RUNNING_EXTENSIONS.pop(ext_name, None) + logger.debug("%s ISO:stop_done ext=%s", LOG_PREFIX, ext_name) + scan_shm_forensics("ISO:stop_extension", refresh_model_context=True) + + scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True) + isolated_class_types_in_graph = needed_class_types.intersection( + {spec.node_name for spec in _ISOLATED_NODE_SPECS} + ) + graph_uses_isolation = bool(isolated_class_types_in_graph) + logger.debug( + "%s ISO:notify_graph_start running=%d needed=%d", + LOG_PREFIX, + len(_RUNNING_EXTENSIONS), + len(needed_class_types), + ) + if graph_uses_isolation: + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + ext_class_types = _get_class_types_for_extension(ext_name) + + # If NONE of this extension's nodes are in the execution graph -> evict. + if not ext_class_types.intersection(needed_class_types): + await _stop_extension( + ext_name, + extension, + "isolated custom_node not in execution graph, evicting", + ) + else: + logger.debug( + "%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph", + LOG_PREFIX, + len(_RUNNING_EXTENSIONS), + ) + + # Isolated child processes add steady VRAM pressure; reclaim host-side models + # at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom. + try: + import comfy.model_management as model_management + + device = model_management.get_torch_device() + if getattr(device, "type", None) == "cuda": + required = max( + model_management.minimum_inference_memory(), + _WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES, + ) + free_before = model_management.get_free_memory(device) + if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation: + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + await _stop_extension( + ext_name, + extension, + f"boundary low-vram restart (free={int(free_before)} target={int(required)})", + ) + if model_management.get_free_memory(device) < required: + model_management.unload_all_models() + model_management.cleanup_models_gc() + model_management.cleanup_models() + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.soft_empty_cache() + except Exception: + logger.debug( + "%s workflow-boundary host VRAM relief failed", LOG_PREFIX, exc_info=True + ) + finally: + scan_shm_forensics("ISO:notify_graph_done", refresh_model_context=True) + logger.debug( + "%s ISO:notify_graph_done running=%d", LOG_PREFIX, len(_RUNNING_EXTENSIONS) + ) + + +async def flush_running_extensions_transport_state() -> int: + await wait_for_model_patcher_quiescence( + timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS, + fail_loud=True, + marker="ISO:flush_transport_wait_idle", + ) + total_flushed = 0 + for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): + flush_fn = getattr(extension, "flush_transport_state", None) + if not callable(flush_fn): + continue + try: + flushed = await flush_fn() + if isinstance(flushed, int): + total_flushed += flushed + if flushed > 0: + logger.debug( + "%s %s workflow-end flush released=%d", + LOG_PREFIX, + ext_name, + flushed, + ) + except Exception: + logger.debug( + "%s %s workflow-end flush failed", LOG_PREFIX, ext_name, exc_info=True + ) + scan_shm_forensics( + "ISO:flush_running_extensions_transport_state", refresh_model_context=True + ) + return total_flushed + + +async def wait_for_model_patcher_quiescence( + timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS, + *, + fail_loud: bool = False, + marker: str = "ISO:wait_model_patcher_idle", +) -> bool: + try: + from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry + + registry = ModelPatcherRegistry() + start = time.perf_counter() + idle = await registry.wait_all_idle(timeout_ms) + elapsed_ms = (time.perf_counter() - start) * 1000.0 + if idle: + logger.debug( + "%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f", + LOG_PREFIX, + marker, + timeout_ms, + elapsed_ms, + ) + return True + + states = await registry.get_all_operation_states() + logger.error( + "%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s", + LOG_PREFIX, + marker, + timeout_ms, + elapsed_ms, + states, + ) + if fail_loud: + raise TimeoutError( + f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms" + ) + return False + except Exception: + if fail_loud: + raise + logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True) + return False + + +def get_claimed_paths() -> Set[Path]: + return _CLAIMED_PATHS + + +def update_rpc_event_loops(loop: "asyncio.AbstractEventLoop | None" = None) -> None: + """Update all active RPC instances with the current event loop. + + This MUST be called at the start of each workflow execution to ensure + RPC calls are scheduled on the correct event loop. This handles the case + where asyncio.run() creates a new event loop for each workflow. + + Args: + loop: The event loop to use. If None, uses asyncio.get_running_loop(). + """ + if loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.get_event_loop() + + update_count = 0 + + # Update RPCs from ExtensionManagers + for manager in _EXTENSION_MANAGERS: + if not hasattr(manager, "extensions"): + continue + for name, extension in manager.extensions.items(): + if hasattr(extension, "rpc") and extension.rpc is not None: + if hasattr(extension.rpc, "update_event_loop"): + extension.rpc.update_event_loop(loop) + update_count += 1 + logger.debug(f"{LOG_PREFIX}Updated loop on extension '{name}'") + + # Also update RPCs from running extensions (they may have direct RPC refs) + for name, extension in _RUNNING_EXTENSIONS.items(): + if hasattr(extension, "rpc") and extension.rpc is not None: + if hasattr(extension.rpc, "update_event_loop"): + extension.rpc.update_event_loop(loop) + update_count += 1 + logger.debug(f"{LOG_PREFIX}Updated loop on running extension '{name}'") + + if update_count > 0: + logger.debug(f"{LOG_PREFIX}Updated event loop on {update_count} RPC instances") + else: + logger.debug( + f"{LOG_PREFIX}No RPC instances found to update (managers={len(_EXTENSION_MANAGERS)}, running={len(_RUNNING_EXTENSIONS)})" + ) + + +__all__ = [ + "LOG_PREFIX", + "initialize_proxies", + "initialize_isolation_nodes", + "start_isolation_loading_early", + "await_isolation_loading", + "notify_execution_graph", + "flush_running_extensions_transport_state", + "wait_for_model_patcher_quiescence", + "get_claimed_paths", + "update_rpc_event_loops", + "IsolatedNodeSpec", + "get_class_types_for_extension", +] diff --git a/comfy/isolation/adapter.py b/comfy/isolation/adapter.py new file mode 100644 index 000000000..99beaa191 --- /dev/null +++ b/comfy/isolation/adapter.py @@ -0,0 +1,641 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position +from __future__ import annotations + +import logging +import os +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped] +from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped] + +try: + from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry + from comfy.isolation.model_patcher_proxy import ( + ModelPatcherProxy, + ModelPatcherRegistry, + ) + from comfy.isolation.model_sampling_proxy import ( + ModelSamplingProxy, + ModelSamplingRegistry, + ) + from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + from comfy.isolation.proxies.prompt_server_impl import PromptServerService + from comfy.isolation.proxies.utils_proxy import UtilsProxy + from comfy.isolation.proxies.progress_proxy import ProgressProxy +except ImportError as exc: # Fail loud if Comfy environment is incomplete + raise ImportError(f"ComfyUI environment incomplete: {exc}") + +logger = logging.getLogger(__name__) + +# Force /dev/shm for shared memory (bwrap makes /tmp private) +import tempfile + +if os.path.exists("/dev/shm"): + # Only override if not already set or if default is not /dev/shm + current_tmp = tempfile.gettempdir() + if not current_tmp.startswith("/dev/shm"): + logger.debug( + f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm" + ) + os.environ["TMPDIR"] = "/dev/shm" + tempfile.tempdir = None # Clear cache to force re-evaluation + + +class ComfyUIAdapter(IsolationAdapter): + # ComfyUI-specific IsolationAdapter implementation + + @property + def identifier(self) -> str: + return "comfyui" + + def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]: + if "ComfyUI" in module_path and "custom_nodes" in module_path: + parts = module_path.split("ComfyUI") + if len(parts) > 1: + comfy_root = parts[0] + "ComfyUI" + return { + "preferred_root": comfy_root, + "additional_paths": [ + os.path.join(comfy_root, "custom_nodes"), + os.path.join(comfy_root, "comfy"), + ], + } + return None + + def setup_child_environment(self, snapshot: Dict[str, Any]) -> None: + comfy_root = snapshot.get("preferred_root") + if not comfy_root: + return + + requirements_path = Path(comfy_root) / "requirements.txt" + if requirements_path.exists(): + import re + + for line in requirements_path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + pkg_name = re.split(r"[<>=!~\[]", line)[0].strip() + if pkg_name: + logging.getLogger(pkg_name).setLevel(logging.ERROR) + + def register_serializers(self, registry: SerializerRegistryProtocol) -> None: + import torch + + def serialize_device(obj: Any) -> Dict[str, Any]: + return {"__type__": "device", "device_str": str(obj)} + + def deserialize_device(data: Dict[str, Any]) -> Any: + return torch.device(data["device_str"]) + + registry.register("device", serialize_device, deserialize_device) + + _VALID_DTYPES = { + "float16", "float32", "float64", "bfloat16", + "int8", "int16", "int32", "int64", + "uint8", "bool", + } + + def serialize_dtype(obj: Any) -> Dict[str, Any]: + return {"__type__": "dtype", "dtype_str": str(obj)} + + def deserialize_dtype(data: Dict[str, Any]) -> Any: + dtype_name = data["dtype_str"].replace("torch.", "") + if dtype_name not in _VALID_DTYPES: + raise ValueError(f"Invalid dtype: {data['dtype_str']}") + return getattr(torch, dtype_name) + + registry.register("dtype", serialize_dtype, deserialize_dtype) + + def serialize_model_patcher(obj: Any) -> Dict[str, Any]: + # Child-side: must already have _instance_id (proxy) + if os.environ.get("PYISOLATE_CHILD") == "1": + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id} + raise RuntimeError( + f"ModelPatcher in child lacks _instance_id: " + f"{type(obj).__module__}.{type(obj).__name__}" + ) + # Host-side: register with registry + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id} + model_id = ModelPatcherRegistry().register(obj) + return {"__type__": "ModelPatcherRef", "model_id": model_id} + + def deserialize_model_patcher(data: Any) -> Any: + """Deserialize ModelPatcher refs; pass through already-materialized objects.""" + if isinstance(data, dict): + return ModelPatcherProxy( + data["model_id"], registry=None, manage_lifecycle=False + ) + return data + + def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any: + """Context-aware ModelPatcherRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return ModelPatcherProxy( + data["model_id"], registry=None, manage_lifecycle=False + ) + else: + return ModelPatcherRegistry()._get_instance(data["model_id"]) + + # Register ModelPatcher type for serialization + registry.register( + "ModelPatcher", serialize_model_patcher, deserialize_model_patcher + ) + # Register ModelPatcherProxy type (already a proxy, just return ref) + registry.register( + "ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher + ) + # Register ModelPatcherRef for deserialization (context-aware: host or child) + registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref) + + def serialize_clip(obj: Any) -> Dict[str, Any]: + if hasattr(obj, "_instance_id"): + return {"__type__": "CLIPRef", "clip_id": obj._instance_id} + clip_id = CLIPRegistry().register(obj) + return {"__type__": "CLIPRef", "clip_id": clip_id} + + def deserialize_clip(data: Any) -> Any: + if isinstance(data, dict): + return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) + return data + + def deserialize_clip_ref(data: Dict[str, Any]) -> Any: + """Context-aware CLIPRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) + else: + return CLIPRegistry()._get_instance(data["clip_id"]) + + # Register CLIP type for serialization + registry.register("CLIP", serialize_clip, deserialize_clip) + # Register CLIPProxy type (already a proxy, just return ref) + registry.register("CLIPProxy", serialize_clip, deserialize_clip) + # Register CLIPRef for deserialization (context-aware: host or child) + registry.register("CLIPRef", None, deserialize_clip_ref) + + def serialize_vae(obj: Any) -> Dict[str, Any]: + if hasattr(obj, "_instance_id"): + return {"__type__": "VAERef", "vae_id": obj._instance_id} + vae_id = VAERegistry().register(obj) + return {"__type__": "VAERef", "vae_id": vae_id} + + def deserialize_vae(data: Any) -> Any: + if isinstance(data, dict): + return VAEProxy(data["vae_id"]) + return data + + def deserialize_vae_ref(data: Dict[str, Any]) -> Any: + """Context-aware VAERef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + # Child: create a proxy + return VAEProxy(data["vae_id"]) + else: + # Host: lookup real VAE from registry + return VAERegistry()._get_instance(data["vae_id"]) + + # Register VAE type for serialization + registry.register("VAE", serialize_vae, deserialize_vae) + # Register VAEProxy type (already a proxy, just return ref) + registry.register("VAEProxy", serialize_vae, deserialize_vae) + # Register VAERef for deserialization (context-aware: host or child) + registry.register("VAERef", None, deserialize_vae_ref) + + # ModelSampling serialization - handles ModelSampling* types + # copyreg removed - no pickle fallback allowed + + def serialize_model_sampling(obj: Any) -> Dict[str, Any]: + # Child-side: must already have _instance_id (proxy) + if os.environ.get("PYISOLATE_CHILD") == "1": + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id} + raise RuntimeError( + f"ModelSampling in child lacks _instance_id: " + f"{type(obj).__module__}.{type(obj).__name__}" + ) + # Host-side pass-through for proxies: do not re-register a proxy as a + # new ModelSamplingRef, or we create proxy-of-proxy indirection. + if hasattr(obj, "_instance_id"): + return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id} + # Host-side: register with ModelSamplingRegistry and return JSON-safe dict + ms_id = ModelSamplingRegistry().register(obj) + return {"__type__": "ModelSamplingRef", "ms_id": ms_id} + + def deserialize_model_sampling(data: Any) -> Any: + """Deserialize ModelSampling refs; pass through already-materialized objects.""" + if isinstance(data, dict): + return ModelSamplingProxy(data["ms_id"]) + return data + + def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any: + """Context-aware ModelSamplingRef deserializer for both host and child.""" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + if is_child: + return ModelSamplingProxy(data["ms_id"]) + else: + return ModelSamplingRegistry()._get_instance(data["ms_id"]) + + # Register all ModelSampling* and StableCascadeSampling classes dynamically + import comfy.model_sampling + + for ms_cls in vars(comfy.model_sampling).values(): + if not isinstance(ms_cls, type): + continue + if not issubclass(ms_cls, torch.nn.Module): + continue + if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"): + continue + registry.register( + ms_cls.__name__, + serialize_model_sampling, + deserialize_model_sampling, + ) + registry.register( + "ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling + ) + # Register ModelSamplingRef for deserialization (context-aware: host or child) + registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref) + + def serialize_cond(obj: Any) -> Dict[str, Any]: + type_key = f"{type(obj).__module__}.{type(obj).__name__}" + return { + "__type__": type_key, + "cond": obj.cond, + } + + def deserialize_cond(data: Dict[str, Any]) -> Any: + import importlib + + type_key = data["__type__"] + module_name, class_name = type_key.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + return cls(data["cond"]) + + def _serialize_public_state(obj: Any) -> Dict[str, Any]: + state: Dict[str, Any] = {} + for key, value in obj.__dict__.items(): + if key.startswith("_"): + continue + if callable(value): + continue + state[key] = value + return state + + def serialize_latent_format(obj: Any) -> Dict[str, Any]: + type_key = f"{type(obj).__module__}.{type(obj).__name__}" + return { + "__type__": type_key, + "state": _serialize_public_state(obj), + } + + def deserialize_latent_format(data: Dict[str, Any]) -> Any: + import importlib + + type_key = data["__type__"] + module_name, class_name = type_key.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + obj = cls() + for key, value in data.get("state", {}).items(): + prop = getattr(type(obj), key, None) + if isinstance(prop, property) and prop.fset is None: + continue + setattr(obj, key, value) + return obj + + import comfy.conds + + for cond_cls in vars(comfy.conds).values(): + if not isinstance(cond_cls, type): + continue + if not issubclass(cond_cls, comfy.conds.CONDRegular): + continue + type_key = f"{cond_cls.__module__}.{cond_cls.__name__}" + registry.register(type_key, serialize_cond, deserialize_cond) + registry.register(cond_cls.__name__, serialize_cond, deserialize_cond) + + import comfy.latent_formats + + for latent_cls in vars(comfy.latent_formats).values(): + if not isinstance(latent_cls, type): + continue + if not issubclass(latent_cls, comfy.latent_formats.LatentFormat): + continue + type_key = f"{latent_cls.__module__}.{latent_cls.__name__}" + registry.register( + type_key, serialize_latent_format, deserialize_latent_format + ) + registry.register( + latent_cls.__name__, serialize_latent_format, deserialize_latent_format + ) + + # V3 API: unwrap NodeOutput.args + def deserialize_node_output(data: Any) -> Any: + return getattr(data, "args", data) + + registry.register("NodeOutput", None, deserialize_node_output) + + # KSAMPLER serializer: stores sampler name instead of function object + # sampler_function is a callable which gets filtered out by JSONSocketTransport + def serialize_ksampler(obj: Any) -> Dict[str, Any]: + func_name = obj.sampler_function.__name__ + # Map function name back to sampler name + if func_name == "sample_unipc": + sampler_name = "uni_pc" + elif func_name == "sample_unipc_bh2": + sampler_name = "uni_pc_bh2" + elif func_name == "dpm_fast_function": + sampler_name = "dpm_fast" + elif func_name == "dpm_adaptive_function": + sampler_name = "dpm_adaptive" + elif func_name.startswith("sample_"): + sampler_name = func_name[7:] # Remove "sample_" prefix + else: + sampler_name = func_name + return { + "__type__": "KSAMPLER", + "sampler_name": sampler_name, + "extra_options": obj.extra_options, + "inpaint_options": obj.inpaint_options, + } + + def deserialize_ksampler(data: Dict[str, Any]) -> Any: + import comfy.samplers + + return comfy.samplers.ksampler( + data["sampler_name"], + data.get("extra_options", {}), + data.get("inpaint_options", {}), + ) + + registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler) + + from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers + + register_hooks_serializers(registry) + + # Generic Numpy Serializer + def serialize_numpy(obj: Any) -> Any: + import torch + + try: + # Attempt zero-copy conversion to Tensor + return torch.from_numpy(obj) + except Exception: + # Fallback for non-numeric arrays (strings, objects, mixes) + return obj.tolist() + + registry.register("ndarray", serialize_numpy, None) + + def serialize_ply(obj: Any) -> Dict[str, Any]: + import base64 + import torch + if obj.raw_data is not None: + return { + "__type__": "PLY", + "raw_data": base64.b64encode(obj.raw_data).decode("ascii"), + } + result: Dict[str, Any] = {"__type__": "PLY", "points": torch.from_numpy(obj.points)} + if obj.colors is not None: + result["colors"] = torch.from_numpy(obj.colors) + if obj.confidence is not None: + result["confidence"] = torch.from_numpy(obj.confidence) + if obj.view_id is not None: + result["view_id"] = torch.from_numpy(obj.view_id) + return result + + def deserialize_ply(data: Any) -> Any: + import base64 + from comfy_api.latest._util.ply_types import PLY + if "raw_data" in data: + return PLY(raw_data=base64.b64decode(data["raw_data"])) + return PLY( + points=data["points"], + colors=data.get("colors"), + confidence=data.get("confidence"), + view_id=data.get("view_id"), + ) + + registry.register("PLY", serialize_ply, deserialize_ply, data_type=True) + + def serialize_npz(obj: Any) -> Dict[str, Any]: + import base64 + return { + "__type__": "NPZ", + "frames": [base64.b64encode(f).decode("ascii") for f in obj.frames], + } + + def deserialize_npz(data: Any) -> Any: + import base64 + from comfy_api.latest._util.npz_types import NPZ + return NPZ(frames=[base64.b64decode(f) for f in data["frames"]]) + + registry.register("NPZ", serialize_npz, deserialize_npz, data_type=True) + + def serialize_file3d(obj: Any) -> Dict[str, Any]: + import base64 + return { + "__type__": "File3D", + "format": obj.format, + "data": base64.b64encode(obj.get_bytes()).decode("ascii"), + } + + def deserialize_file3d(data: Any) -> Any: + import base64 + from io import BytesIO + from comfy_api.latest._util.geometry_types import File3D + return File3D(BytesIO(base64.b64decode(data["data"])), file_format=data["format"]) + + registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True) + + def serialize_video(obj: Any) -> Dict[str, Any]: + components = obj.get_components() + images = components.images.detach() if components.images.requires_grad else components.images + result: Dict[str, Any] = { + "__type__": "VIDEO", + "images": images, + "frame_rate_num": components.frame_rate.numerator, + "frame_rate_den": components.frame_rate.denominator, + } + if components.audio is not None: + waveform = components.audio["waveform"] + if waveform.requires_grad: + waveform = waveform.detach() + result["audio_waveform"] = waveform + result["audio_sample_rate"] = components.audio["sample_rate"] + if components.metadata is not None: + result["metadata"] = components.metadata + return result + + def deserialize_video(data: Any) -> Any: + from fractions import Fraction + from comfy_api.latest._input_impl.video_types import VideoFromComponents + from comfy_api.latest._util.video_types import VideoComponents + audio = None + if "audio_waveform" in data: + audio = {"waveform": data["audio_waveform"], "sample_rate": data["audio_sample_rate"]} + components = VideoComponents( + images=data["images"], + frame_rate=Fraction(data["frame_rate_num"], data["frame_rate_den"]), + audio=audio, + metadata=data.get("metadata"), + ) + return VideoFromComponents(components) + + registry.register("VIDEO", serialize_video, deserialize_video, data_type=True) + registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True) + registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True) + + def provide_rpc_services(self) -> List[type[ProxiedSingleton]]: + return [ + PromptServerService, + FolderPathsProxy, + ModelManagementProxy, + UtilsProxy, + ProgressProxy, + VAERegistry, + CLIPRegistry, + ModelPatcherRegistry, + ModelSamplingRegistry, + FirstStageModelRegistry, + ] + + def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None: + # Resolve the real name whether it's an instance or the Singleton class itself + api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__ + + if api_name == "FolderPathsProxy": + import folder_paths + + # Replace module-level functions with proxy methods + # This is aggressive but necessary for transparent proxying + # Handle both instance and class cases + instance = api() if isinstance(api, type) else api + for name in dir(instance): + if not name.startswith("_"): + setattr(folder_paths, name, getattr(instance, name)) + + # Fence: isolated children get writable temp inside sandbox + if os.environ.get("PYISOLATE_CHILD") == "1": + _child_temp = os.path.join("/tmp", "comfyui_temp") + os.makedirs(_child_temp, exist_ok=True) + folder_paths.temp_directory = _child_temp + + return + + if api_name == "ModelManagementProxy": + import comfy.model_management + + instance = api() if isinstance(api, type) else api + # Replace module-level functions with proxy methods + for name in dir(instance): + if not name.startswith("_"): + setattr(comfy.model_management, name, getattr(instance, name)) + return + + if api_name == "UtilsProxy": + import comfy.utils + + # Static Injection of RPC mechanism to ensure Child can access it + # independent of instance lifecycle. + api.set_rpc(rpc) + + # Don't overwrite host hook (infinite recursion) + return + + if api_name == "PromptServerProxy": + # Defer heavy import to child context + import server + + instance = api() if isinstance(api, type) else api + proxy = ( + instance.instance + ) # PromptServerProxy instance has .instance property returning self + + original_register_route = proxy.register_route + + def register_route_wrapper( + method: str, path: str, handler: Callable[..., Any] + ) -> None: + callback_id = rpc.register_callback(handler) + loop = getattr(rpc, "loop", None) + if loop and loop.is_running(): + import asyncio + + asyncio.create_task( + original_register_route( + method, path, handler=callback_id, is_callback=True + ) + ) + else: + original_register_route( + method, path, handler=callback_id, is_callback=True + ) + return None + + proxy.register_route = register_route_wrapper + + class RouteTableDefProxy: + def __init__(self, proxy_instance: Any): + self.proxy = proxy_instance + + def get( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("GET", path, handler) + return handler + + return decorator + + def post( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("POST", path, handler) + return handler + + return decorator + + def patch( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("PATCH", path, handler) + return handler + + return decorator + + def put( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("PUT", path, handler) + return handler + + return decorator + + def delete( + self, path: str, **kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + self.proxy.register_route("DELETE", path, handler) + return handler + + return decorator + + proxy.routes = RouteTableDefProxy(proxy) + + if ( + hasattr(server, "PromptServer") + and getattr(server.PromptServer, "instance", None) != proxy + ): + server.PromptServer.instance = proxy diff --git a/comfy/isolation/child_hooks.py b/comfy/isolation/child_hooks.py new file mode 100644 index 000000000..a1ba201ac --- /dev/null +++ b/comfy/isolation/child_hooks.py @@ -0,0 +1,141 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation +# Child process initialization for PyIsolate +import logging +import os + +logger = logging.getLogger(__name__) + + +def is_child_process() -> bool: + return os.environ.get("PYISOLATE_CHILD") == "1" + + +def initialize_child_process() -> None: + # Manual RPC injection + try: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc: + _setup_prompt_server_stub(rpc) + _setup_utils_proxy(rpc) + else: + logger.warning("Could not get child RPC instance for manual injection") + _setup_prompt_server_stub() + _setup_utils_proxy() + except Exception as e: + logger.error(f"Manual RPC Injection failed: {e}") + _setup_prompt_server_stub() + _setup_utils_proxy() + + _setup_logging() + + +def _setup_prompt_server_stub(rpc=None) -> None: + try: + from .proxies.prompt_server_impl import PromptServerStub + import sys + import types + + # Mock server module + if "server" not in sys.modules: + mock_server = types.ModuleType("server") + sys.modules["server"] = mock_server + + server = sys.modules["server"] + + if not hasattr(server, "PromptServer"): + + class MockPromptServer: + pass + + server.PromptServer = MockPromptServer + + stub = PromptServerStub() + + if rpc: + PromptServerStub.set_rpc(rpc) + if hasattr(stub, "set_rpc"): + stub.set_rpc(rpc) + + server.PromptServer.instance = stub + + except Exception as e: + logger.error(f"Failed to setup PromptServerStub: {e}") + + +def _setup_utils_proxy(rpc=None) -> None: + try: + import comfy.utils + import asyncio + + # Capture main loop during initialization (safe context) + main_loop = None + try: + main_loop = asyncio.get_running_loop() + except RuntimeError: + try: + main_loop = asyncio.get_event_loop() + except RuntimeError: + pass + + try: + from .proxies.base import set_global_loop + + if main_loop: + set_global_loop(main_loop) + except ImportError: + pass + + # Sync hook wrapper for progress updates + def sync_hook_wrapper( + value: int, total: int, preview: None = None, node_id: None = None + ) -> None: + if node_id is None: + try: + from comfy_execution.utils import get_executing_context + + ctx = get_executing_context() + if ctx: + node_id = ctx.node_id + else: + pass + except Exception: + pass + + # Bypass blocked event loop by direct outbox injection + if rpc: + try: + # Use captured main loop if available (for threaded execution), or current loop + loop = main_loop + if loop is None: + loop = asyncio.get_event_loop() + + rpc.outbox.put( + { + "kind": "call", + "object_id": "UtilsProxy", + "parent_call_id": None, # We are root here usually + "calling_loop": loop, + "future": loop.create_future(), # Dummy future + "method": "progress_bar_hook", + "args": (value, total, preview, node_id), + "kwargs": {}, + } + ) + + except Exception as e: + logging.getLogger(__name__).error(f"Manual Inject Failed: {e}") + else: + logging.getLogger(__name__).warning( + "No RPC instance available for progress update" + ) + + comfy.utils.PROGRESS_BAR_HOOK = sync_hook_wrapper + + except Exception as e: + logger.error(f"Failed to setup UtilsProxy hook: {e}") + + +def _setup_logging() -> None: + logging.getLogger().setLevel(logging.INFO) diff --git a/comfy/isolation/clip_proxy.py b/comfy/isolation/clip_proxy.py new file mode 100644 index 000000000..371665314 --- /dev/null +++ b/comfy/isolation/clip_proxy.py @@ -0,0 +1,327 @@ +# pylint: disable=attribute-defined-outside-init,import-outside-toplevel,logging-fstring-interpolation +# CLIP Proxy implementation +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Optional + +from comfy.isolation.proxies.base import ( + IS_CHILD_PROCESS, + BaseProxy, + BaseRegistry, + detach_if_grad, +) + +if TYPE_CHECKING: + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + +class CondStageModelRegistry(BaseRegistry[Any]): + _type_prefix = "cond_stage_model" + + async def get_property(self, instance_id: str, name: str) -> Any: + obj = self._get_instance(instance_id) + return getattr(obj, name) + + +class CondStageModelProxy(BaseProxy[CondStageModelRegistry]): + _registry_class = CondStageModelRegistry + __module__ = "comfy.sd" + + def __getattr__(self, name: str) -> Any: + try: + return self._call_rpc("get_property", name) + except Exception as e: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) from e + + def __repr__(self) -> str: + return f"" + + +class TokenizerRegistry(BaseRegistry[Any]): + _type_prefix = "tokenizer" + + async def get_property(self, instance_id: str, name: str) -> Any: + obj = self._get_instance(instance_id) + return getattr(obj, name) + + +class TokenizerProxy(BaseProxy[TokenizerRegistry]): + _registry_class = TokenizerRegistry + __module__ = "comfy.sd" + + def __getattr__(self, name: str) -> Any: + try: + return self._call_rpc("get_property", name) + except Exception as e: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) from e + + def __repr__(self) -> str: + return f"" + + +logger = logging.getLogger(__name__) + + +class CLIPRegistry(BaseRegistry[Any]): + _type_prefix = "clip" + _allowed_setters = { + "layer_idx", + "tokenizer_options", + "use_clip_schedule", + "apply_hooks_to_conds", + } + + async def get_ram_usage(self, instance_id: str) -> int: + return self._get_instance(instance_id).get_ram_usage() + + async def get_patcher_id(self, instance_id: str) -> str: + from comfy.isolation.model_patcher_proxy import ModelPatcherRegistry + + return ModelPatcherRegistry().register(self._get_instance(instance_id).patcher) + + async def get_cond_stage_model_id(self, instance_id: str) -> str: + return CondStageModelRegistry().register( + self._get_instance(instance_id).cond_stage_model + ) + + async def get_tokenizer_id(self, instance_id: str) -> str: + return TokenizerRegistry().register(self._get_instance(instance_id).tokenizer) + + async def load_model(self, instance_id: str) -> None: + self._get_instance(instance_id).load_model() + + async def clip_layer(self, instance_id: str, layer_idx: int) -> None: + self._get_instance(instance_id).clip_layer(layer_idx) + + async def set_tokenizer_option( + self, instance_id: str, option_name: str, value: Any + ) -> None: + self._get_instance(instance_id).set_tokenizer_option(option_name, value) + + async def get_property(self, instance_id: str, name: str) -> Any: + return getattr(self._get_instance(instance_id), name) + + async def set_property(self, instance_id: str, name: str, value: Any) -> None: + if name not in self._allowed_setters: + raise PermissionError(f"Setting '{name}' is not allowed via RPC") + setattr(self._get_instance(instance_id), name, value) + + async def tokenize( + self, instance_id: str, text: str, return_word_ids: bool = False, **kwargs: Any + ) -> Any: + return self._get_instance(instance_id).tokenize( + text, return_word_ids=return_word_ids, **kwargs + ) + + async def encode(self, instance_id: str, text: str) -> Any: + return detach_if_grad(self._get_instance(instance_id).encode(text)) + + async def encode_from_tokens( + self, + instance_id: str, + tokens: Any, + return_pooled: bool = False, + return_dict: bool = False, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).encode_from_tokens( + tokens, return_pooled=return_pooled, return_dict=return_dict + ) + ) + + async def encode_from_tokens_scheduled( + self, + instance_id: str, + tokens: Any, + unprojected: bool = False, + add_dict: Optional[dict] = None, + show_pbar: bool = True, + ) -> Any: + add_dict = add_dict or {} + return detach_if_grad( + self._get_instance(instance_id).encode_from_tokens_scheduled( + tokens, unprojected=unprojected, add_dict=add_dict, show_pbar=show_pbar + ) + ) + + async def add_patches( + self, + instance_id: str, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> Any: + return self._get_instance(instance_id).add_patches( + patches, strength_patch=strength_patch, strength_model=strength_model + ) + + async def get_key_patches(self, instance_id: str) -> Any: + return self._get_instance(instance_id).get_key_patches() + + async def load_sd( + self, instance_id: str, sd: dict, full_model: bool = False + ) -> Any: + return self._get_instance(instance_id).load_sd(sd, full_model=full_model) + + async def get_sd(self, instance_id: str) -> Any: + return self._get_instance(instance_id).get_sd() + + async def clone(self, instance_id: str) -> str: + return self.register(self._get_instance(instance_id).clone()) + + +class CLIPProxy(BaseProxy[CLIPRegistry]): + _registry_class = CLIPRegistry + __module__ = "comfy.sd" + + def get_ram_usage(self) -> int: + return self._call_rpc("get_ram_usage") + + @property + def patcher(self) -> "ModelPatcherProxy": + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + if not hasattr(self, "_patcher_proxy"): + patcher_id = self._call_rpc("get_patcher_id") + self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False) + return self._patcher_proxy + + @patcher.setter + def patcher(self, value: Any) -> None: + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + if isinstance(value, ModelPatcherProxy): + self._patcher_proxy = value + else: + logger.warning( + f"Attempted to set CLIPProxy.patcher to non-proxy object: {value}" + ) + + @property + def cond_stage_model(self) -> CondStageModelProxy: + if not hasattr(self, "_cond_stage_model_proxy"): + csm_id = self._call_rpc("get_cond_stage_model_id") + self._cond_stage_model_proxy = CondStageModelProxy( + csm_id, manage_lifecycle=False + ) + return self._cond_stage_model_proxy + + @property + def tokenizer(self) -> TokenizerProxy: + if not hasattr(self, "_tokenizer_proxy"): + tok_id = self._call_rpc("get_tokenizer_id") + self._tokenizer_proxy = TokenizerProxy(tok_id, manage_lifecycle=False) + return self._tokenizer_proxy + + def load_model(self) -> ModelPatcherProxy: + self._call_rpc("load_model") + return self.patcher + + @property + def layer_idx(self) -> Optional[int]: + return self._call_rpc("get_property", "layer_idx") + + @layer_idx.setter + def layer_idx(self, value: Optional[int]) -> None: + self._call_rpc("set_property", "layer_idx", value) + + @property + def tokenizer_options(self) -> dict: + return self._call_rpc("get_property", "tokenizer_options") + + @tokenizer_options.setter + def tokenizer_options(self, value: dict) -> None: + self._call_rpc("set_property", "tokenizer_options", value) + + @property + def use_clip_schedule(self) -> bool: + return self._call_rpc("get_property", "use_clip_schedule") + + @use_clip_schedule.setter + def use_clip_schedule(self, value: bool) -> None: + self._call_rpc("set_property", "use_clip_schedule", value) + + @property + def apply_hooks_to_conds(self) -> Any: + return self._call_rpc("get_property", "apply_hooks_to_conds") + + @apply_hooks_to_conds.setter + def apply_hooks_to_conds(self, value: Any) -> None: + self._call_rpc("set_property", "apply_hooks_to_conds", value) + + def clip_layer(self, layer_idx: int) -> None: + return self._call_rpc("clip_layer", layer_idx) + + def set_tokenizer_option(self, option_name: str, value: Any) -> None: + return self._call_rpc("set_tokenizer_option", option_name, value) + + def tokenize(self, text: str, return_word_ids: bool = False, **kwargs: Any) -> Any: + return self._call_rpc( + "tokenize", text, return_word_ids=return_word_ids, **kwargs + ) + + def encode(self, text: str) -> Any: + return self._call_rpc("encode", text) + + def encode_from_tokens( + self, tokens: Any, return_pooled: bool = False, return_dict: bool = False + ) -> Any: + res = self._call_rpc( + "encode_from_tokens", + tokens, + return_pooled=return_pooled, + return_dict=return_dict, + ) + if return_pooled and isinstance(res, list) and not return_dict: + return tuple(res) + return res + + def encode_from_tokens_scheduled( + self, + tokens: Any, + unprojected: bool = False, + add_dict: Optional[dict] = None, + show_pbar: bool = True, + ) -> Any: + add_dict = add_dict or {} + return self._call_rpc( + "encode_from_tokens_scheduled", + tokens, + unprojected=unprojected, + add_dict=add_dict, + show_pbar=show_pbar, + ) + + def add_patches( + self, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0 + ) -> Any: + return self._call_rpc( + "add_patches", + patches, + strength_patch=strength_patch, + strength_model=strength_model, + ) + + def get_key_patches(self) -> Any: + return self._call_rpc("get_key_patches") + + def load_sd(self, sd: dict, full_model: bool = False) -> Any: + return self._call_rpc("load_sd", sd, full_model=full_model) + + def get_sd(self) -> Any: + return self._call_rpc("get_sd") + + def clone(self) -> CLIPProxy: + new_id = self._call_rpc("clone") + return CLIPProxy(new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS) + + +if not IS_CHILD_PROCESS: + _CLIP_REGISTRY_SINGLETON = CLIPRegistry() + _COND_STAGE_MODEL_REGISTRY_SINGLETON = CondStageModelRegistry() + _TOKENIZER_REGISTRY_SINGLETON = TokenizerRegistry() diff --git a/comfy/isolation/extension_loader.py b/comfy/isolation/extension_loader.py new file mode 100644 index 000000000..632b08857 --- /dev/null +++ b/comfy/isolation/extension_loader.py @@ -0,0 +1,341 @@ +# pylint: disable=cyclic-import,import-outside-toplevel,redefined-outer-name +from __future__ import annotations + +import logging +import os +import inspect +import sys +import types +import platform +from pathlib import Path +from typing import Callable, Dict, List, Tuple + +import pyisolate +from pyisolate import ExtensionManager, ExtensionManagerConfig +from packaging.requirements import InvalidRequirement, Requirement +from packaging.utils import canonicalize_name + +from .extension_wrapper import ComfyNodeExtension +from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache +from .host_policy import load_host_policy + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +logger = logging.getLogger(__name__) + + +async def _stop_extension_safe( + extension: ComfyNodeExtension, extension_name: str +) -> None: + try: + stop_result = extension.stop() + if inspect.isawaitable(stop_result): + await stop_result + except Exception: + logger.debug("][ %s stop failed", extension_name, exc_info=True) + + +def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str: + req, sep, marker = dep.partition(";") + req = req.strip() + marker_suffix = f";{marker}" if sep else "" + + def _resolve_local_path(local_path: str) -> Path | None: + for base in base_paths: + candidate = (base / local_path).resolve() + if candidate.exists(): + return candidate + return None + + if req.startswith("./") or req.startswith("../"): + resolved = _resolve_local_path(req) + if resolved is not None: + return f"{resolved}{marker_suffix}" + + if req.startswith("file://"): + raw = req[len("file://") :] + if raw.startswith("./") or raw.startswith("../"): + resolved = _resolve_local_path(raw) + if resolved is not None: + return f"file://{resolved}{marker_suffix}" + + return dep + + +def _dependency_name_from_spec(dep: str) -> str | None: + stripped = dep.strip() + if not stripped or stripped == "-e" or stripped.startswith("-e "): + return None + if stripped.startswith(("/", "./", "../", "file://")): + return None + + try: + return canonicalize_name(Requirement(stripped).name) + except InvalidRequirement: + return None + + +def _parse_cuda_wheels_config( + tool_config: dict[str, object], dependencies: list[str] +) -> dict[str, object] | None: + raw_config = tool_config.get("cuda_wheels") + if raw_config is None: + return None + if not isinstance(raw_config, dict): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels] must be a table" + ) + + index_url = raw_config.get("index_url") + if not isinstance(index_url, str) or not index_url.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string" + ) + + packages = raw_config.get("packages") + if not isinstance(packages, list) or not all( + isinstance(package_name, str) and package_name.strip() + for package_name in packages + ): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings" + ) + + declared_dependencies = { + dependency_name + for dep in dependencies + if (dependency_name := _dependency_name_from_spec(dep)) is not None + } + normalized_packages = [canonicalize_name(package_name) for package_name in packages] + missing = [ + package_name + for package_name in normalized_packages + if package_name not in declared_dependencies + ] + if missing: + missing_joined = ", ".join(sorted(missing)) + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: " + f"{missing_joined}" + ) + + package_map = raw_config.get("package_map", {}) + if not isinstance(package_map, dict): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] must be a table" + ) + + normalized_package_map: dict[str, str] = {} + for dependency_name, index_package_name in package_map.items(): + if not isinstance(dependency_name, str) or not dependency_name.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings" + ) + if not isinstance(index_package_name, str) or not index_package_name.strip(): + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings" + ) + canonical_dependency_name = canonicalize_name(dependency_name) + if canonical_dependency_name not in normalized_packages: + raise ExtensionLoadError( + "[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in " + "[tool.comfy.isolation.cuda_wheels.packages]" + ) + normalized_package_map[canonical_dependency_name] = index_package_name.strip() + + return { + "index_url": index_url.rstrip("/") + "/", + "packages": normalized_packages, + "package_map": normalized_package_map, + } + + +def get_enforcement_policy() -> Dict[str, bool]: + return { + "force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1", + "force_sandbox": os.environ.get("PYISOLATE_ENFORCE_SANDBOX") == "1", + } + + +class ExtensionLoadError(RuntimeError): + pass + + +def register_dummy_module(extension_name: str, node_dir: Path) -> None: + normalized_name = extension_name.replace("-", "_").replace(".", "_") + if normalized_name not in sys.modules: + dummy_module = types.ModuleType(normalized_name) + dummy_module.__file__ = str(node_dir / "__init__.py") + dummy_module.__path__ = [str(node_dir)] + dummy_module.__package__ = normalized_name + sys.modules[normalized_name] = dummy_module + + +def _is_stale_node_cache(cached_data: Dict[str, Dict]) -> bool: + for details in cached_data.values(): + if not isinstance(details, dict): + return True + if details.get("is_v3") and "schema_v1" not in details: + return True + return False + + +async def load_isolated_node( + node_dir: Path, + manifest_path: Path, + logger: logging.Logger, + build_stub_class: Callable[[str, Dict[str, object], ComfyNodeExtension], type], + venv_root: Path, + extension_managers: List[ExtensionManager], +) -> List[Tuple[str, str, type]]: + try: + with manifest_path.open("rb") as handle: + manifest_data = tomllib.load(handle) + except Exception as e: + logger.warning(f"][ Failed to parse {manifest_path}: {e}") + return [] + + # Parse [tool.comfy.isolation] + tool_config = manifest_data.get("tool", {}).get("comfy", {}).get("isolation", {}) + can_isolate = tool_config.get("can_isolate", False) + share_torch = tool_config.get("share_torch", False) + + # Parse [project] dependencies + project_config = manifest_data.get("project", {}) + dependencies = project_config.get("dependencies", []) + if not isinstance(dependencies, list): + dependencies = [] + + # Get extension name (default to folder name if not in project.name) + extension_name = project_config.get("name", node_dir.name) + + # LOGIC: Isolation Decision + policy = get_enforcement_policy() + isolated = can_isolate or policy["force_isolated"] + + if not isolated: + return [] + + logger.info(f"][ Loading isolated node: {extension_name}") + + import folder_paths + + base_paths = [Path(folder_paths.base_path), node_dir] + dependencies = [ + _normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep + for dep in dependencies + ] + cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies) + + manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root)) + manager: ExtensionManager = pyisolate.ExtensionManager( + ComfyNodeExtension, manager_config + ) + extension_managers.append(manager) + + host_policy = load_host_policy(Path(folder_paths.base_path)) + + sandbox_config = {} + is_linux = platform.system() == "Linux" + if is_linux and isolated: + sandbox_config = { + "network": host_policy["allow_network"], + "writable_paths": host_policy["writable_paths"], + "readonly_paths": host_policy["readonly_paths"], + } + share_cuda_ipc = share_torch and is_linux + + extension_config = { + "name": extension_name, + "module_path": str(node_dir), + "isolated": True, + "dependencies": dependencies, + "share_torch": share_torch, + "share_cuda_ipc": share_cuda_ipc, + "sandbox": sandbox_config, + } + if cuda_wheels is not None: + extension_config["cuda_wheels"] = cuda_wheels + + extension = manager.load_extension(extension_config) + register_dummy_module(extension_name, node_dir) + + # Try cache first (lazy spawn) + if is_cache_valid(node_dir, manifest_path, venv_root): + cached_data = load_from_cache(node_dir, venv_root) + if cached_data: + if _is_stale_node_cache(cached_data): + logger.debug( + "][ %s cache is stale/incompatible; rebuilding metadata", + extension_name, + ) + else: + logger.debug(f"][ {extension_name} loaded from cache") + specs: List[Tuple[str, str, type]] = [] + for node_name, details in cached_data.items(): + stub_cls = build_stub_class(node_name, details, extension) + specs.append( + (node_name, details.get("display_name", node_name), stub_cls) + ) + return specs + + # Cache miss - spawn process and get metadata + logger.debug(f"][ {extension_name} cache miss, spawning process for metadata") + + try: + remote_nodes: Dict[str, str] = await extension.list_nodes() + except Exception as exc: + logger.warning( + "][ %s metadata discovery failed, skipping isolated load: %s", + extension_name, + exc, + ) + await _stop_extension_safe(extension, extension_name) + return [] + + if not remote_nodes: + logger.debug("][ %s exposed no isolated nodes; skipping", extension_name) + await _stop_extension_safe(extension, extension_name) + return [] + + specs: List[Tuple[str, str, type]] = [] + cache_data: Dict[str, Dict] = {} + + for node_name, display_name in remote_nodes.items(): + try: + details = await extension.get_node_details(node_name) + except Exception as exc: + logger.warning( + "][ %s failed to load metadata for %s, skipping node: %s", + extension_name, + node_name, + exc, + ) + continue + details["display_name"] = display_name + cache_data[node_name] = details + stub_cls = build_stub_class(node_name, details, extension) + specs.append((node_name, display_name, stub_cls)) + + if not specs: + logger.warning( + "][ %s produced no usable nodes after metadata scan; skipping", + extension_name, + ) + await _stop_extension_safe(extension, extension_name) + return [] + + # Save metadata to cache for future runs + save_to_cache(node_dir, venv_root, cache_data, manifest_path) + logger.debug(f"][ {extension_name} metadata cached") + + # EJECT: Kill process after getting metadata (will respawn on first execution) + await _stop_extension_safe(extension, extension_name) + + return specs + + +__all__ = ["ExtensionLoadError", "register_dummy_module", "load_isolated_node"] diff --git a/comfy/isolation/extension_wrapper.py b/comfy/isolation/extension_wrapper.py new file mode 100644 index 000000000..d58acb18e --- /dev/null +++ b/comfy/isolation/extension_wrapper.py @@ -0,0 +1,680 @@ +# pylint: disable=consider-using-from-import,cyclic-import,import-outside-toplevel,logging-fstring-interpolation,protected-access,wrong-import-position +from __future__ import annotations + +import asyncio +import torch + + +class AttrDict(dict): + def __getattr__(self, item): + try: + return self[item] + except KeyError as e: + raise AttributeError(item) from e + + def copy(self): + return AttrDict(super().copy()) + + +import importlib +import inspect +import json +import logging +import os +import sys +import uuid +from dataclasses import asdict +from typing import Any, Dict, List, Tuple + +from pyisolate import ExtensionBase + +from comfy_api.internal import _ComfyNodeInternal + +LOG_PREFIX = "][" +V3_DISCOVERY_TIMEOUT = 30 +_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 + +logger = logging.getLogger(__name__) + + +def _flush_tensor_transport_state(marker: str) -> int: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + return 0 + if not callable(flush_tensor_keeper): + return 0 + flushed = flush_tensor_keeper() + if flushed > 0: + logger.debug( + "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed + ) + return flushed + + +def _relieve_child_vram_pressure(marker: str) -> None: + import comfy.model_management as model_management + + model_management.cleanup_models_gc() + model_management.cleanup_models() + + device = model_management.get_torch_device() + if not hasattr(device, "type") or device.type == "cpu": + return + + required = max( + model_management.minimum_inference_memory(), + _PRE_EXEC_MIN_FREE_VRAM_BYTES, + ) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=True) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.cleanup_models() + model_management.soft_empty_cache() + logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required) + + +def _sanitize_for_transport(value): + primitives = (str, int, float, bool, type(None)) + if isinstance(value, primitives): + return value + + cls_name = value.__class__.__name__ + if cls_name == "FlexibleOptionalInputType": + return { + "__pyisolate_flexible_optional__": True, + "type": _sanitize_for_transport(getattr(value, "type", "*")), + } + if cls_name == "AnyType": + return {"__pyisolate_any_type__": True, "value": str(value)} + if cls_name == "ByPassTypeTuple": + return { + "__pyisolate_bypass_tuple__": [ + _sanitize_for_transport(v) for v in tuple(value) + ] + } + + if isinstance(value, dict): + return {k: _sanitize_for_transport(v) for k, v in value.items()} + if isinstance(value, tuple): + return {"__pyisolate_tuple__": [_sanitize_for_transport(v) for v in value]} + if isinstance(value, list): + return [_sanitize_for_transport(v) for v in value] + + return str(value) + + +# Re-export RemoteObjectHandle from pyisolate for backward compatibility +# The canonical definition is now in pyisolate._internal.remote_handle +from pyisolate._internal.remote_handle import RemoteObjectHandle # noqa: E402,F401 + + +class ComfyNodeExtension(ExtensionBase): + def __init__(self) -> None: + super().__init__() + self.node_classes: Dict[str, type] = {} + self.display_names: Dict[str, str] = {} + self.node_instances: Dict[str, Any] = {} + self.remote_objects: Dict[str, Any] = {} + self._route_handlers: Dict[str, Any] = {} + self._module: Any = None + + async def on_module_loaded(self, module: Any) -> None: + self._module = module + + # Registries are initialized in host_hooks.py initialize_host_process() + # They auto-register via ProxiedSingleton when instantiated + # NO additional setup required here - if a registry is missing from host_hooks, it WILL fail + + self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {} + self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {} + + try: + from comfy_api.latest import ComfyExtension + + for name, obj in inspect.getmembers(module): + if not ( + inspect.isclass(obj) + and issubclass(obj, ComfyExtension) + and obj is not ComfyExtension + ): + continue + if not obj.__module__.startswith(module.__name__): + continue + try: + ext_instance = obj() + try: + await asyncio.wait_for( + ext_instance.on_load(), timeout=V3_DISCOVERY_TIMEOUT + ) + except asyncio.TimeoutError: + logger.error( + "%s V3 Extension %s timed out in on_load()", + LOG_PREFIX, + name, + ) + continue + try: + v3_nodes = await asyncio.wait_for( + ext_instance.get_node_list(), timeout=V3_DISCOVERY_TIMEOUT + ) + except asyncio.TimeoutError: + logger.error( + "%s V3 Extension %s timed out in get_node_list()", + LOG_PREFIX, + name, + ) + continue + for node_cls in v3_nodes: + if hasattr(node_cls, "GET_SCHEMA"): + schema = node_cls.GET_SCHEMA() + self.node_classes[schema.node_id] = node_cls + if schema.display_name: + self.display_names[schema.node_id] = schema.display_name + except Exception as e: + logger.error("%s V3 Extension %s failed: %s", LOG_PREFIX, name, e) + except ImportError: + pass + + module_name = getattr(module, "__name__", "isolated_nodes") + for node_cls in self.node_classes.values(): + if hasattr(node_cls, "__module__") and "/" in str(node_cls.__module__): + node_cls.__module__ = module_name + + self.node_instances = {} + + async def list_nodes(self) -> Dict[str, str]: + return {name: self.display_names.get(name, name) for name in self.node_classes} + + async def get_node_info(self, node_name: str) -> Dict[str, Any]: + return await self.get_node_details(node_name) + + async def get_node_details(self, node_name: str) -> Dict[str, Any]: + node_cls = self._get_node_class(node_name) + is_v3 = issubclass(node_cls, _ComfyNodeInternal) + + input_types_raw = ( + node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {} + ) + output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None) + if output_is_list is not None: + output_is_list = tuple(bool(x) for x in output_is_list) + + details: Dict[str, Any] = { + "input_types": _sanitize_for_transport(input_types_raw), + "return_types": tuple( + str(t) for t in getattr(node_cls, "RETURN_TYPES", ()) + ), + "return_names": getattr(node_cls, "RETURN_NAMES", None), + "function": str(getattr(node_cls, "FUNCTION", "execute")), + "category": str(getattr(node_cls, "CATEGORY", "")), + "output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)), + "output_is_list": output_is_list, + "is_v3": is_v3, + } + + if is_v3: + try: + schema = node_cls.GET_SCHEMA() + schema_v1 = asdict(schema.get_v1_info(node_cls)) + try: + schema_v3 = asdict(schema.get_v3_info(node_cls)) + except (AttributeError, TypeError): + schema_v3 = self._build_schema_v3_fallback(schema) + details.update( + { + "schema_v1": schema_v1, + "schema_v3": schema_v3, + "hidden": [h.value for h in (schema.hidden or [])], + "description": getattr(schema, "description", ""), + "deprecated": bool(getattr(node_cls, "DEPRECATED", False)), + "experimental": bool(getattr(node_cls, "EXPERIMENTAL", False)), + "api_node": bool(getattr(node_cls, "API_NODE", False)), + "input_is_list": bool( + getattr(node_cls, "INPUT_IS_LIST", False) + ), + "not_idempotent": bool( + getattr(node_cls, "NOT_IDEMPOTENT", False) + ), + } + ) + except Exception as exc: + logger.warning( + "%s V3 schema serialization failed for %s: %s", + LOG_PREFIX, + node_name, + exc, + ) + return details + + def _build_schema_v3_fallback(self, schema) -> Dict[str, Any]: + input_dict: Dict[str, Any] = {} + output_dict: Dict[str, Any] = {} + hidden_list: List[str] = [] + + if getattr(schema, "inputs", None): + for inp in schema.inputs: + self._add_schema_io_v3(inp, input_dict) + if getattr(schema, "outputs", None): + for out in schema.outputs: + self._add_schema_io_v3(out, output_dict) + if getattr(schema, "hidden", None): + for h in schema.hidden: + hidden_list.append(getattr(h, "value", str(h))) + + return { + "input": input_dict, + "output": output_dict, + "hidden": hidden_list, + "name": getattr(schema, "node_id", None), + "display_name": getattr(schema, "display_name", None), + "description": getattr(schema, "description", None), + "category": getattr(schema, "category", None), + "output_node": getattr(schema, "is_output_node", False), + "deprecated": getattr(schema, "is_deprecated", False), + "experimental": getattr(schema, "is_experimental", False), + "api_node": getattr(schema, "is_api_node", False), + } + + def _add_schema_io_v3(self, io_obj: Any, target: Dict[str, Any]) -> None: + io_id = getattr(io_obj, "id", None) + if io_id is None: + return + + io_type_fn = getattr(io_obj, "get_io_type", None) + io_type = ( + io_type_fn() if callable(io_type_fn) else getattr(io_obj, "io_type", None) + ) + + as_dict_fn = getattr(io_obj, "as_dict", None) + payload = as_dict_fn() if callable(as_dict_fn) else {} + + target[str(io_id)] = (io_type, payload) + + async def get_input_types(self, node_name: str) -> Dict[str, Any]: + node_cls = self._get_node_class(node_name) + if hasattr(node_cls, "INPUT_TYPES"): + return node_cls.INPUT_TYPES() + return {} + + async def execute_node(self, node_name: str, **inputs: Any) -> Tuple[Any, ...]: + logger.debug( + "%s ISO:child_execute_start ext=%s node=%s input_keys=%d", + LOG_PREFIX, + getattr(self, "name", "?"), + node_name, + len(inputs), + ) + if os.environ.get("PYISOLATE_CHILD") == "1": + _relieve_child_vram_pressure("EXT:pre_execute") + + resolved_inputs = self._resolve_remote_objects(inputs) + + instance = self._get_node_instance(node_name) + node_cls = self._get_node_class(node_name) + + # V3 API nodes expect hidden parameters in cls.hidden, not as kwargs + # Hidden params come through RPC as string keys like "Hidden.prompt" + from comfy_api.latest._io import Hidden, HiddenHolder + + # Map string representations back to Hidden enum keys + hidden_string_map = { + "Hidden.unique_id": Hidden.unique_id, + "Hidden.prompt": Hidden.prompt, + "Hidden.extra_pnginfo": Hidden.extra_pnginfo, + "Hidden.dynprompt": Hidden.dynprompt, + "Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org, + "Hidden.api_key_comfy_org": Hidden.api_key_comfy_org, + # Uppercase enum VALUE forms — V3 execution engine passes these + "UNIQUE_ID": Hidden.unique_id, + "PROMPT": Hidden.prompt, + "EXTRA_PNGINFO": Hidden.extra_pnginfo, + "DYNPROMPT": Hidden.dynprompt, + "AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org, + "API_KEY_COMFY_ORG": Hidden.api_key_comfy_org, + } + + # Find and extract hidden parameters (both enum and string form) + hidden_found = {} + keys_to_remove = [] + + for key in list(resolved_inputs.keys()): + # Check string form first (from RPC serialization) + if key in hidden_string_map: + hidden_found[hidden_string_map[key]] = resolved_inputs[key] + keys_to_remove.append(key) + # Also check enum form (direct calls) + elif isinstance(key, Hidden): + hidden_found[key] = resolved_inputs[key] + keys_to_remove.append(key) + + # Remove hidden params from kwargs + for key in keys_to_remove: + resolved_inputs.pop(key) + + # Set hidden on node class if any hidden params found + if hidden_found: + if not hasattr(node_cls, "hidden") or node_cls.hidden is None: + node_cls.hidden = HiddenHolder.from_dict(hidden_found) + else: + # Update existing hidden holder + for key, value in hidden_found.items(): + setattr(node_cls.hidden, key.value.lower(), value) + + function_name = getattr(node_cls, "FUNCTION", "execute") + if not hasattr(instance, function_name): + raise AttributeError(f"Node {node_name} missing callable '{function_name}'") + + handler = getattr(instance, function_name) + + try: + if asyncio.iscoroutinefunction(handler): + result = await handler(**resolved_inputs) + else: + import functools + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor( + None, functools.partial(handler, **resolved_inputs) + ) + except Exception: + logger.exception( + "%s ISO:child_execute_error ext=%s node=%s", + LOG_PREFIX, + getattr(self, "name", "?"), + node_name, + ) + raise + + if type(result).__name__ == "NodeOutput": + result = result.args + if self._is_comfy_protocol_return(result): + wrapped = self._wrap_unpicklable_objects(result) + return wrapped + + if not isinstance(result, tuple): + result = (result,) + wrapped = self._wrap_unpicklable_objects(result) + return wrapped + + async def flush_transport_state(self) -> int: + if os.environ.get("PYISOLATE_CHILD") != "1": + return 0 + logger.debug( + "%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?") + ) + flushed = _flush_tensor_transport_state("EXT:workflow_end") + try: + from comfy.isolation.model_patcher_proxy_registry import ( + ModelPatcherRegistry, + ) + + registry = ModelPatcherRegistry() + removed = registry.sweep_pending_cleanup() + if removed > 0: + logger.debug( + "%s EXT:workflow_end registry sweep removed=%d", LOG_PREFIX, removed + ) + except Exception: + logger.debug( + "%s EXT:workflow_end registry sweep failed", LOG_PREFIX, exc_info=True + ) + logger.debug( + "%s ISO:child_flush_done ext=%s flushed=%d", + LOG_PREFIX, + getattr(self, "name", "?"), + flushed, + ) + return flushed + + async def get_remote_object(self, object_id: str) -> Any: + """Retrieve a remote object by ID for host-side deserialization.""" + if object_id not in self.remote_objects: + raise KeyError(f"Remote object {object_id} not found") + + return self.remote_objects[object_id] + + def _wrap_unpicklable_objects(self, data: Any) -> Any: + if isinstance(data, (str, int, float, bool, type(None))): + return data + if isinstance(data, torch.Tensor): + tensor = data.detach() if data.requires_grad else data + if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu": + return tensor.cpu() + return tensor + + # Special-case clip vision outputs: preserve attribute access by packing fields + if hasattr(data, "penultimate_hidden_states") or hasattr( + data, "last_hidden_state" + ): + fields = {} + for attr in ( + "penultimate_hidden_states", + "last_hidden_state", + "image_embeds", + "text_embeds", + ): + if hasattr(data, attr): + try: + fields[attr] = self._wrap_unpicklable_objects( + getattr(data, attr) + ) + except Exception: + pass + if fields: + return {"__pyisolate_attribute_container__": True, "data": fields} + + # Avoid converting arbitrary objects with stateful methods (models, etc.) + # They will be handled via RemoteObjectHandle below. + + type_name = type(data).__name__ + if type_name == "ModelPatcherProxy": + return {"__type__": "ModelPatcherRef", "model_id": data._instance_id} + if type_name == "CLIPProxy": + return {"__type__": "CLIPRef", "clip_id": data._instance_id} + if type_name == "VAEProxy": + return {"__type__": "VAERef", "vae_id": data._instance_id} + if type_name == "ModelSamplingProxy": + return {"__type__": "ModelSamplingRef", "ms_id": data._instance_id} + + if isinstance(data, (list, tuple)): + wrapped = [self._wrap_unpicklable_objects(item) for item in data] + return tuple(wrapped) if isinstance(data, tuple) else wrapped + if isinstance(data, dict): + converted_dict = { + k: self._wrap_unpicklable_objects(v) for k, v in data.items() + } + return {"__pyisolate_attrdict__": True, "data": converted_dict} + + from pyisolate._internal.serialization_registry import SerializerRegistry + + registry = SerializerRegistry.get_instance() + if registry.is_data_type(type_name): + serializer = registry.get_serializer(type_name) + if serializer: + return serializer(data) + + object_id = str(uuid.uuid4()) + self.remote_objects[object_id] = data + return RemoteObjectHandle(object_id, type(data).__name__) + + def _resolve_remote_objects(self, data: Any) -> Any: + if isinstance(data, RemoteObjectHandle): + if data.object_id not in self.remote_objects: + raise KeyError(f"Remote object {data.object_id} not found") + return self.remote_objects[data.object_id] + + if isinstance(data, dict): + ref_type = data.get("__type__") + if ref_type in ("CLIPRef", "ModelPatcherRef", "VAERef"): + from pyisolate._internal.model_serialization import ( + deserialize_proxy_result, + ) + + return deserialize_proxy_result(data) + if ref_type == "ModelSamplingRef": + from pyisolate._internal.model_serialization import ( + deserialize_proxy_result, + ) + + return deserialize_proxy_result(data) + return {k: self._resolve_remote_objects(v) for k, v in data.items()} + + if isinstance(data, (list, tuple)): + resolved = [self._resolve_remote_objects(item) for item in data] + return tuple(resolved) if isinstance(data, tuple) else resolved + return data + + def _get_node_class(self, node_name: str) -> type: + if node_name not in self.node_classes: + raise KeyError(f"Unknown node: {node_name}") + return self.node_classes[node_name] + + def _get_node_instance(self, node_name: str) -> Any: + if node_name not in self.node_instances: + if node_name not in self.node_classes: + raise KeyError(f"Unknown node: {node_name}") + self.node_instances[node_name] = self.node_classes[node_name]() + return self.node_instances[node_name] + + async def before_module_loaded(self) -> None: + # Inject initialization here if we think this is the child + try: + from comfy.isolation import initialize_proxies + + initialize_proxies() + except Exception as e: + logging.getLogger(__name__).error( + f"Failed to call initialize_proxies in before_module_loaded: {e}" + ) + + await super().before_module_loaded() + try: + from comfy_api.latest import ComfyAPI_latest + from .proxies.progress_proxy import ProgressProxy + + ComfyAPI_latest.Execution = ProgressProxy + # ComfyAPI_latest.execution = ProgressProxy() # Eliminated to avoid Singleton collision + # fp_proxy = FolderPathsProxy() # Eliminated to avoid Singleton collision + # latest_ui.folder_paths = fp_proxy + # latest_resources.folder_paths = fp_proxy + except Exception: + pass + + async def call_route_handler( + self, + handler_module: str, + handler_func: str, + request_data: Dict[str, Any], + ) -> Any: + cache_key = f"{handler_module}.{handler_func}" + if cache_key not in self._route_handlers: + if self._module is not None and hasattr(self._module, "__file__"): + node_dir = os.path.dirname(self._module.__file__) + if node_dir not in sys.path: + sys.path.insert(0, node_dir) + try: + module = importlib.import_module(handler_module) + self._route_handlers[cache_key] = getattr(module, handler_func) + except (ImportError, AttributeError) as e: + raise ValueError(f"Route handler not found: {cache_key}") from e + + handler = self._route_handlers[cache_key] + mock_request = MockRequest(request_data) + + if asyncio.iscoroutinefunction(handler): + result = await handler(mock_request) + else: + result = handler(mock_request) + return self._serialize_response(result) + + def _is_comfy_protocol_return(self, result: Any) -> bool: + """ + Check if the result matches the ComfyUI 'Protocol Return' schema. + + A Protocol Return is a dictionary containing specific reserved keys that + ComfyUI's execution engine interprets as instructions (UI updates, + Workflow expansion, etc.) rather than purely data outputs. + + Schema: + - Must be a dict + - Must contain at least one of: 'ui', 'result', 'expand' + """ + if not isinstance(result, dict): + return False + return any(key in result for key in ("ui", "result", "expand")) + + def _serialize_response(self, response: Any) -> Dict[str, Any]: + if response is None: + return {"type": "text", "body": "", "status": 204} + if isinstance(response, dict): + return {"type": "json", "body": response, "status": 200} + if isinstance(response, str): + return {"type": "text", "body": response, "status": 200} + if hasattr(response, "text") and hasattr(response, "status"): + return { + "type": "text", + "body": response.text + if hasattr(response, "text") + else str(response.body), + "status": response.status, + "headers": dict(response.headers) + if hasattr(response, "headers") + else {}, + } + if hasattr(response, "body") and hasattr(response, "status"): + body = response.body + if isinstance(body, bytes): + try: + return { + "type": "text", + "body": body.decode("utf-8"), + "status": response.status, + } + except UnicodeDecodeError: + return { + "type": "binary", + "body": body.hex(), + "status": response.status, + } + return {"type": "json", "body": body, "status": response.status} + return {"type": "text", "body": str(response), "status": 200} + + +class MockRequest: + def __init__(self, data: Dict[str, Any]): + self.method = data.get("method", "GET") + self.path = data.get("path", "/") + self.query = data.get("query", {}) + self._body = data.get("body", {}) + self._text = data.get("text", "") + self.headers = data.get("headers", {}) + self.content_type = data.get( + "content_type", self.headers.get("Content-Type", "application/json") + ) + self.match_info = data.get("match_info", {}) + + async def json(self) -> Any: + if isinstance(self._body, dict): + return self._body + if isinstance(self._body, str): + return json.loads(self._body) + return {} + + async def post(self) -> Dict[str, Any]: + if isinstance(self._body, dict): + return self._body + return {} + + async def text(self) -> str: + if self._text: + return self._text + if isinstance(self._body, str): + return self._body + if isinstance(self._body, dict): + return json.dumps(self._body) + return "" + + async def read(self) -> bytes: + return (await self.text()).encode("utf-8") diff --git a/comfy/isolation/host_hooks.py b/comfy/isolation/host_hooks.py new file mode 100644 index 000000000..86cde10a8 --- /dev/null +++ b/comfy/isolation/host_hooks.py @@ -0,0 +1,26 @@ +# pylint: disable=import-outside-toplevel +# Host process initialization for PyIsolate +import logging + +logger = logging.getLogger(__name__) + + +def initialize_host_process() -> None: + root = logging.getLogger() + for handler in root.handlers[:]: + root.removeHandler(handler) + root.addHandler(logging.NullHandler()) + + from .proxies.folder_paths_proxy import FolderPathsProxy + from .proxies.model_management_proxy import ModelManagementProxy + from .proxies.progress_proxy import ProgressProxy + from .proxies.prompt_server_impl import PromptServerService + from .proxies.utils_proxy import UtilsProxy + from .vae_proxy import VAERegistry + + FolderPathsProxy() + ModelManagementProxy() + ProgressProxy() + PromptServerService() + UtilsProxy() + VAERegistry() diff --git a/comfy/isolation/host_policy.py b/comfy/isolation/host_policy.py new file mode 100644 index 000000000..660dcda20 --- /dev/null +++ b/comfy/isolation/host_policy.py @@ -0,0 +1,83 @@ +# pylint: disable=logging-fstring-interpolation +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Dict, List, TypedDict + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +logger = logging.getLogger(__name__) + + +class HostSecurityPolicy(TypedDict): + allow_network: bool + writable_paths: List[str] + readonly_paths: List[str] + whitelist: Dict[str, str] + + +DEFAULT_POLICY: HostSecurityPolicy = { + "allow_network": False, + "writable_paths": ["/dev/shm", "/tmp"], + "readonly_paths": [], + "whitelist": {}, +} + + +def _default_policy() -> HostSecurityPolicy: + return { + "allow_network": DEFAULT_POLICY["allow_network"], + "writable_paths": list(DEFAULT_POLICY["writable_paths"]), + "readonly_paths": list(DEFAULT_POLICY["readonly_paths"]), + "whitelist": dict(DEFAULT_POLICY["whitelist"]), + } + + +def load_host_policy(comfy_root: Path) -> HostSecurityPolicy: + config_path = comfy_root / "pyproject.toml" + policy = _default_policy() + + if not config_path.exists(): + logger.debug("Host policy file missing at %s, using defaults.", config_path) + return policy + + try: + with config_path.open("rb") as f: + data = tomllib.load(f) + except Exception: + logger.warning( + "Failed to parse host policy from %s, using defaults.", + config_path, + exc_info=True, + ) + return policy + + tool_config = data.get("tool", {}).get("comfy", {}).get("host", {}) + if not isinstance(tool_config, dict): + logger.debug("No [tool.comfy.host] section found, using defaults.") + return policy + + if "allow_network" in tool_config: + policy["allow_network"] = bool(tool_config["allow_network"]) + + if "writable_paths" in tool_config: + policy["writable_paths"] = [str(p) for p in tool_config["writable_paths"]] + + if "readonly_paths" in tool_config: + policy["readonly_paths"] = [str(p) for p in tool_config["readonly_paths"]] + + whitelist_raw = tool_config.get("whitelist") + if isinstance(whitelist_raw, dict): + policy["whitelist"] = {str(k): str(v) for k, v in whitelist_raw.items()} + + logger.debug( + f"Loaded Host Policy: {len(policy['whitelist'])} whitelisted nodes, Network={policy['allow_network']}" + ) + return policy + + +__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"] diff --git a/comfy/isolation/manifest_loader.py b/comfy/isolation/manifest_loader.py new file mode 100644 index 000000000..42007302f --- /dev/null +++ b/comfy/isolation/manifest_loader.py @@ -0,0 +1,186 @@ +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +import hashlib +import json +import logging +import os +import sys +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import folder_paths + +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[no-redef] + +LOG_PREFIX = "][" +logger = logging.getLogger(__name__) + +CACHE_SUBDIR = "cache" +CACHE_KEY_FILE = "cache_key" +CACHE_DATA_FILE = "node_info.json" +CACHE_KEY_LENGTH = 16 + + +def find_manifest_directories() -> List[Tuple[Path, Path]]: + """Find custom node directories containing a valid pyproject.toml with [tool.comfy.isolation].""" + manifest_dirs: List[Tuple[Path, Path]] = [] + + # Standard custom_nodes paths + for base_path in folder_paths.get_folder_paths("custom_nodes"): + base = Path(base_path) + if not base.exists() or not base.is_dir(): + continue + + for entry in base.iterdir(): + if not entry.is_dir(): + continue + + # Look for pyproject.toml + manifest = entry / "pyproject.toml" + if not manifest.exists(): + continue + + # Validate [tool.comfy.isolation] section existence + try: + with manifest.open("rb") as f: + data = tomllib.load(f) + + if ( + "tool" in data + and "comfy" in data["tool"] + and "isolation" in data["tool"]["comfy"] + ): + manifest_dirs.append((entry, manifest)) + + except Exception: + continue + + return manifest_dirs + + +def compute_cache_key(node_dir: Path, manifest_path: Path) -> str: + """Hash manifest + .py mtimes + Python version + PyIsolate version.""" + hasher = hashlib.sha256() + + try: + # Hashing the manifest content ensures config changes invalidate cache + hasher.update(manifest_path.read_bytes()) + except OSError: + hasher.update(b"__manifest_read_error__") + + try: + py_files = sorted(node_dir.rglob("*.py")) + for py_file in py_files: + rel_path = py_file.relative_to(node_dir) + if "__pycache__" in str(rel_path) or ".venv" in str(rel_path): + continue + hasher.update(str(rel_path).encode("utf-8")) + try: + hasher.update(str(py_file.stat().st_mtime).encode("utf-8")) + except OSError: + hasher.update(b"__file_stat_error__") + except OSError: + hasher.update(b"__dir_scan_error__") + + hasher.update(sys.version.encode("utf-8")) + + try: + import pyisolate + + hasher.update(pyisolate.__version__.encode("utf-8")) + except (ImportError, AttributeError): + hasher.update(b"__pyisolate_unknown__") + + return hasher.hexdigest()[:CACHE_KEY_LENGTH] + + +def get_cache_path(node_dir: Path, venv_root: Path) -> Tuple[Path, Path]: + """Return (cache_key_file, cache_data_file) in venv_root/{node}/cache/.""" + cache_dir = venv_root / node_dir.name / CACHE_SUBDIR + return (cache_dir / CACHE_KEY_FILE, cache_dir / CACHE_DATA_FILE) + + +def is_cache_valid(node_dir: Path, manifest_path: Path, venv_root: Path) -> bool: + """Return True only if stored cache key matches current computed key.""" + try: + cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root) + if not cache_key_file.exists() or not cache_data_file.exists(): + return False + current_key = compute_cache_key(node_dir, manifest_path) + stored_key = cache_key_file.read_text(encoding="utf-8").strip() + return current_key == stored_key + except Exception as e: + logger.debug( + "%s Cache validation error for %s: %s", LOG_PREFIX, node_dir.name, e + ) + return False + + +def load_from_cache(node_dir: Path, venv_root: Path) -> Optional[Dict[str, Any]]: + """Load node metadata from cache, return None on any error.""" + try: + _, cache_data_file = get_cache_path(node_dir, venv_root) + if not cache_data_file.exists(): + return None + data = json.loads(cache_data_file.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return None + return data + except Exception: + return None + + +def save_to_cache( + node_dir: Path, venv_root: Path, node_data: Dict[str, Any], manifest_path: Path +) -> None: + """Save node metadata and cache key atomically.""" + try: + cache_key_file, cache_data_file = get_cache_path(node_dir, venv_root) + cache_dir = cache_key_file.parent + cache_dir.mkdir(parents=True, exist_ok=True) + cache_key = compute_cache_key(node_dir, manifest_path) + + # Atomic write: data + tmp_data_fd, tmp_data_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp") + try: + with os.fdopen(tmp_data_fd, "w", encoding="utf-8") as f: + json.dump(node_data, f, indent=2) + os.replace(tmp_data_path, cache_data_file) + except Exception: + try: + os.unlink(tmp_data_path) + except OSError: + pass + raise + + # Atomic write: key + tmp_key_fd, tmp_key_path = tempfile.mkstemp(dir=str(cache_dir), suffix=".tmp") + try: + with os.fdopen(tmp_key_fd, "w", encoding="utf-8") as f: + f.write(cache_key) + os.replace(tmp_key_path, cache_key_file) + except Exception: + try: + os.unlink(tmp_key_path) + except OSError: + pass + raise + + except Exception as e: + logger.warning("%s Cache save failed for %s: %s", LOG_PREFIX, node_dir.name, e) + + +__all__ = [ + "LOG_PREFIX", + "find_manifest_directories", + "compute_cache_key", + "get_cache_path", + "is_cache_valid", + "load_from_cache", + "save_to_cache", +] diff --git a/comfy/isolation/model_patcher_proxy.py b/comfy/isolation/model_patcher_proxy.py new file mode 100644 index 000000000..e1c513933 --- /dev/null +++ b/comfy/isolation/model_patcher_proxy.py @@ -0,0 +1,861 @@ +# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access +# RPC proxy for ModelPatcher (parent process) +from __future__ import annotations + +import logging +from typing import Any, Optional, List, Set, Dict, Callable + +from comfy.isolation.proxies.base import ( + IS_CHILD_PROCESS, + BaseProxy, +) +from comfy.isolation.model_patcher_proxy_registry import ( + ModelPatcherRegistry, + AutoPatcherEjector, +) + +logger = logging.getLogger(__name__) + + +class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): + _registry_class = ModelPatcherRegistry + __module__ = "comfy.model_patcher" + _APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024 + + def _get_rpc(self) -> Any: + if self._rpc_caller is None: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc is not None: + self._rpc_caller = rpc.create_caller( + self._registry_class, self._registry_class.get_remote_id() + ) + else: + self._rpc_caller = self._registry + return self._rpc_caller + + def get_all_callbacks(self, call_type: str = None) -> Any: + return self._call_rpc("get_all_callbacks", call_type) + + def get_all_wrappers(self, wrapper_type: str = None) -> Any: + return self._call_rpc("get_all_wrappers", wrapper_type) + + def _load_list(self, *args, **kwargs) -> Any: + return self._call_rpc("load_list_internal", *args, **kwargs) + + def prepare_hook_patches_current_keyframe( + self, t: Any, hook_group: Any, model_options: Any + ) -> None: + self._call_rpc( + "prepare_hook_patches_current_keyframe", t, hook_group, model_options + ) + + def add_hook_patches( + self, + hook: Any, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> None: + self._call_rpc( + "add_hook_patches", hook, patches, strength_patch, strength_model + ) + + def clear_cached_hook_weights(self) -> None: + self._call_rpc("clear_cached_hook_weights") + + def get_combined_hook_patches(self, hooks: Any) -> Any: + return self._call_rpc("get_combined_hook_patches", hooks) + + def get_additional_models_with_key(self, key: str) -> Any: + return self._call_rpc("get_additional_models_with_key", key) + + @property + def object_patches(self) -> Any: + return self._call_rpc("get_object_patches") + + @property + def patches(self) -> Any: + res = self._call_rpc("get_patches") + if isinstance(res, dict): + new_res = {} + for k, v in res.items(): + new_list = [] + for item in v: + if isinstance(item, list): + new_list.append(tuple(item)) + else: + new_list.append(item) + new_res[k] = new_list + return new_res + return res + + @property + def pinned(self) -> Set: + val = self._call_rpc("get_patcher_attr", "pinned") + return set(val) if val is not None else set() + + @property + def hook_patches(self) -> Dict: + val = self._call_rpc("get_patcher_attr", "hook_patches") + if val is None: + return {} + try: + from comfy.hooks import _HookRef + import json + + new_val = {} + for k, v in val.items(): + if isinstance(k, str): + if k.startswith("PYISOLATE_HOOKREF:"): + ref_id = k.split(":", 1)[1] + h = _HookRef() + h._pyisolate_id = ref_id + new_val[h] = v + elif k.startswith("__pyisolate_key__"): + try: + json_str = k[len("__pyisolate_key__") :] + data = json.loads(json_str) + ref_id = None + if isinstance(data, list): + for item in data: + if ( + isinstance(item, list) + and len(item) == 2 + and item[0] == "id" + ): + ref_id = item[1] + break + if ref_id: + h = _HookRef() + h._pyisolate_id = ref_id + new_val[h] = v + else: + new_val[k] = v + except Exception: + new_val[k] = v + else: + new_val[k] = v + else: + new_val[k] = v + return new_val + except ImportError: + return val + + def set_hook_mode(self, hook_mode: Any) -> None: + self._call_rpc("set_hook_mode", hook_mode) + + def register_all_hook_patches( + self, + hooks: Any, + target_dict: Any, + model_options: Any = None, + registered: Any = None, + ) -> None: + self._call_rpc( + "register_all_hook_patches", hooks, target_dict, model_options, registered + ) + + def is_clone(self, other: Any) -> bool: + if isinstance(other, ModelPatcherProxy): + return self._call_rpc("is_clone_by_id", other._instance_id) + return False + + def clone(self) -> ModelPatcherProxy: + new_id = self._call_rpc("clone") + return ModelPatcherProxy( + new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS + ) + + def clone_has_same_weights(self, clone: Any) -> bool: + if isinstance(clone, ModelPatcherProxy): + return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id) + if not IS_CHILD_PROCESS: + return self._call_rpc("is_clone", clone) + return False + + def get_model_object(self, name: str) -> Any: + return self._call_rpc("get_model_object", name) + + @property + def model_options(self) -> dict: + data = self._call_rpc("get_model_options") + import json + + def _decode_keys(obj): + if isinstance(obj, dict): + new_d = {} + for k, v in obj.items(): + if isinstance(k, str) and k.startswith("__pyisolate_key__"): + try: + json_str = k[17:] + val = json.loads(json_str) + if isinstance(val, list): + val = tuple(val) + new_d[val] = _decode_keys(v) + except: + new_d[k] = _decode_keys(v) + else: + new_d[k] = _decode_keys(v) + return new_d + if isinstance(obj, list): + return [_decode_keys(x) for x in obj] + return obj + + return _decode_keys(data) + + @model_options.setter + def model_options(self, value: dict) -> None: + self._call_rpc("set_model_options", value) + + def apply_hooks(self, hooks: Any) -> Any: + return self._call_rpc("apply_hooks", hooks) + + def prepare_state(self, timestep: Any) -> Any: + return self._call_rpc("prepare_state", timestep) + + def restore_hook_patches(self) -> None: + self._call_rpc("restore_hook_patches") + + def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None: + self._call_rpc("unpatch_hooks", whitelist_keys_set) + + def model_patches_to(self, device: Any) -> Any: + return self._call_rpc("model_patches_to", device) + + def partially_load( + self, device: Any, extra_memory: Any, force_patch_weights: bool = False + ) -> Any: + return self._call_rpc( + "partially_load", device, extra_memory, force_patch_weights + ) + + def partially_unload( + self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False + ) -> int: + return self._call_rpc( + "partially_unload", device_to, memory_to_free, force_patch_weights + ) + + def load( + self, + device_to: Any = None, + lowvram_model_memory: int = 0, + force_patch_weights: bool = False, + full_load: bool = False, + ) -> None: + self._call_rpc( + "load", device_to, lowvram_model_memory, force_patch_weights, full_load + ) + + def patch_model( + self, + device_to: Any = None, + lowvram_model_memory: int = 0, + load_weights: bool = True, + force_patch_weights: bool = False, + ) -> Any: + self._call_rpc( + "patch_model", + device_to, + lowvram_model_memory, + load_weights, + force_patch_weights, + ) + return self + + def unpatch_model( + self, device_to: Any = None, unpatch_weights: bool = True + ) -> None: + self._call_rpc("unpatch_model", device_to, unpatch_weights) + + def detach(self, unpatch_all: bool = True) -> Any: + self._call_rpc("detach", unpatch_all) + return self.model + + def _cpu_tensor_bytes(self, obj: Any) -> int: + import torch + + if isinstance(obj, torch.Tensor): + if obj.device.type == "cpu": + return obj.nbytes + return 0 + if isinstance(obj, dict): + return sum(self._cpu_tensor_bytes(v) for v in obj.values()) + if isinstance(obj, (list, tuple)): + return sum(self._cpu_tensor_bytes(v) for v in obj) + return 0 + + def _ensure_apply_model_headroom(self, required_bytes: int) -> bool: + if required_bytes <= 0: + return True + + import torch + import comfy.model_management as model_management + + target_raw = self.load_device + try: + if isinstance(target_raw, torch.device): + target = target_raw + elif isinstance(target_raw, str): + target = torch.device(target_raw) + elif isinstance(target_raw, int): + target = torch.device(f"cuda:{target_raw}") + else: + target = torch.device(target_raw) + except Exception: + return True + + if target.type != "cuda": + return True + + required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES + if model_management.get_free_memory(target) >= required: + return True + + model_management.cleanup_models_gc() + model_management.cleanup_models() + model_management.soft_empty_cache() + + if model_management.get_free_memory(target) < required: + model_management.free_memory(required, target, for_dynamic=True) + model_management.soft_empty_cache() + + if model_management.get_free_memory(target) < required: + # Escalate to non-dynamic unloading before dispatching CUDA transfer. + model_management.free_memory(required, target, for_dynamic=False) + model_management.soft_empty_cache() + + if model_management.get_free_memory(target) < required: + model_management.load_models_gpu( + [self], + minimum_memory_required=required, + ) + + return model_management.get_free_memory(target) >= required + + def apply_model(self, *args, **kwargs) -> Any: + import torch + + def _preferred_device() -> Any: + for value in args: + if isinstance(value, torch.Tensor): + return value.device + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + return value.device + return None + + def _move_result_to_device(obj: Any, device: Any) -> Any: + if device is None: + return obj + if isinstance(obj, torch.Tensor): + return obj.to(device) if obj.device != device else obj + if isinstance(obj, dict): + return {k: _move_result_to_device(v, device) for k, v in obj.items()} + if isinstance(obj, list): + return [_move_result_to_device(v, device) for v in obj] + if isinstance(obj, tuple): + return tuple(_move_result_to_device(v, device) for v in obj) + return obj + + # DynamicVRAM models must keep load/offload decisions in host process. + # Child-side CUDA staging here can deadlock before first inference RPC. + if self.is_dynamic(): + out = self._call_rpc("inner_model_apply_model", args, kwargs) + return _move_result_to_device(out, _preferred_device()) + + required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs) + self._ensure_apply_model_headroom(required_bytes) + + def _to_cuda(obj: Any) -> Any: + if isinstance(obj, torch.Tensor) and obj.device.type == "cpu": + return obj.to("cuda") + if isinstance(obj, dict): + return {k: _to_cuda(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cuda(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cuda(v) for v in obj) + return obj + + try: + args_cuda = _to_cuda(args) + kwargs_cuda = _to_cuda(kwargs) + except torch.OutOfMemoryError: + self._ensure_apply_model_headroom(required_bytes) + args_cuda = _to_cuda(args) + kwargs_cuda = _to_cuda(kwargs) + + out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda) + return _move_result_to_device(out, _preferred_device()) + + def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any: + keys = self._call_rpc("model_state_dict", filter_prefix) + return dict.fromkeys(keys, None) + + def add_patches(self, *args: Any, **kwargs: Any) -> Any: + res = self._call_rpc("add_patches", *args, **kwargs) + if isinstance(res, list): + return [tuple(x) if isinstance(x, list) else x for x in res] + return res + + def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any: + return self._call_rpc("get_key_patches", filter_prefix) + + def patch_weight_to_device(self, key, device_to=None, inplace_update=False): + self._call_rpc("patch_weight_to_device", key, device_to, inplace_update) + + def pin_weight_to_device(self, key): + self._call_rpc("pin_weight_to_device", key) + + def unpin_weight(self, key): + self._call_rpc("unpin_weight", key) + + def unpin_all_weights(self): + self._call_rpc("unpin_all_weights") + + def calculate_weight(self, patches, weight, key, intermediate_dtype=None): + return self._call_rpc( + "calculate_weight", patches, weight, key, intermediate_dtype + ) + + def inject_model(self) -> None: + self._call_rpc("inject_model") + + def eject_model(self) -> None: + self._call_rpc("eject_model") + + def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any: + return AutoPatcherEjector( + self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only + ) + + @property + def is_injected(self) -> bool: + return self._call_rpc("get_is_injected") + + @property + def skip_injection(self) -> bool: + return self._call_rpc("get_skip_injection") + + @skip_injection.setter + def skip_injection(self, value: bool) -> None: + self._call_rpc("set_skip_injection", value) + + def clean_hooks(self) -> None: + self._call_rpc("clean_hooks") + + def pre_run(self) -> None: + self._call_rpc("pre_run") + + def cleanup(self) -> None: + try: + self._call_rpc("cleanup") + except Exception: + logger.debug( + "ModelPatcherProxy cleanup RPC failed for %s", + self._instance_id, + exc_info=True, + ) + finally: + super().cleanup() + + @property + def model(self) -> _InnerModelProxy: + return _InnerModelProxy(self) + + def __getattr__(self, name: str) -> Any: + _whitelisted_attrs = { + "hook_patches_backup", + "hook_backup", + "cached_hook_patches", + "current_hooks", + "forced_hooks", + "is_clip", + "patches_uuid", + "pinned", + "attachments", + "additional_models", + "injections", + "hook_patches", + "model_lowvram", + "model_loaded_weight_memory", + "backup", + "object_patches_backup", + "weight_wrapper_patches", + "weight_inplace_update", + "force_cast_weights", + } + if name in _whitelisted_attrs: + return self._call_rpc("get_patcher_attr", name) + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def load_lora( + self, + lora_path: str, + strength_model: float, + clip: Optional[Any] = None, + strength_clip: float = 1.0, + ) -> tuple: + clip_id = None + if clip is not None: + clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None)) + result = self._call_rpc( + "load_lora", lora_path, strength_model, clip_id, strength_clip + ) + new_model = None + if result.get("model_id"): + new_model = ModelPatcherProxy( + result["model_id"], + self._registry, + manage_lifecycle=not IS_CHILD_PROCESS, + ) + new_clip = None + if result.get("clip_id"): + from comfy.isolation.clip_proxy import CLIPProxy + + new_clip = CLIPProxy(result["clip_id"]) + return (new_model, new_clip) + + @property + def load_device(self) -> Any: + return self._call_rpc("get_load_device") + + @property + def offload_device(self) -> Any: + return self._call_rpc("get_offload_device") + + @property + def device(self) -> Any: + return self.load_device + + def current_loaded_device(self) -> Any: + return self._call_rpc("current_loaded_device") + + @property + def size(self) -> int: + return self._call_rpc("get_size") + + def model_size(self) -> Any: + return self._call_rpc("model_size") + + def loaded_size(self) -> Any: + return self._call_rpc("loaded_size") + + def get_ram_usage(self) -> int: + return self._call_rpc("get_ram_usage") + + def lowvram_patch_counter(self) -> int: + return self._call_rpc("lowvram_patch_counter") + + def memory_required(self, input_shape: Any) -> Any: + return self._call_rpc("memory_required", input_shape) + + def get_operation_state(self) -> Dict[str, Any]: + state = self._call_rpc("get_operation_state") + return state if isinstance(state, dict) else {} + + def wait_for_idle(self, timeout_ms: int = 0) -> bool: + return bool(self._call_rpc("wait_for_idle", timeout_ms)) + + def is_dynamic(self) -> bool: + return bool(self._call_rpc("is_dynamic")) + + def get_free_memory(self, device: Any) -> Any: + return self._call_rpc("get_free_memory", device) + + def partially_unload_ram(self, ram_to_unload: int) -> Any: + return self._call_rpc("partially_unload_ram", ram_to_unload) + + def model_dtype(self) -> Any: + res = self._call_rpc("model_dtype") + if isinstance(res, str) and res.startswith("torch."): + try: + import torch + + attr = res.split(".")[-1] + if hasattr(torch, attr): + return getattr(torch, attr) + except ImportError: + pass + return res + + @property + def hook_mode(self) -> Any: + return self._call_rpc("get_hook_mode") + + @hook_mode.setter + def hook_mode(self, value: Any) -> None: + self._call_rpc("set_hook_mode", value) + + def set_model_sampler_cfg_function( + self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False + ) -> None: + self._call_rpc( + "set_model_sampler_cfg_function", + sampler_cfg_function, + disable_cfg1_optimization, + ) + + def set_model_sampler_post_cfg_function( + self, post_cfg_function: Any, disable_cfg1_optimization: bool = False + ) -> None: + self._call_rpc( + "set_model_sampler_post_cfg_function", + post_cfg_function, + disable_cfg1_optimization, + ) + + def set_model_sampler_pre_cfg_function( + self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False + ) -> None: + self._call_rpc( + "set_model_sampler_pre_cfg_function", + pre_cfg_function, + disable_cfg1_optimization, + ) + + def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None: + self._call_rpc("set_model_sampler_calc_cond_batch_function", fn) + + def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None: + self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function) + + def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None: + self._call_rpc("set_model_denoise_mask_function", denoise_mask_function) + + def set_model_patch(self, patch: Any, name: str) -> None: + self._call_rpc("set_model_patch", patch, name) + + def set_model_patch_replace( + self, + patch: Any, + name: str, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self._call_rpc( + "set_model_patch_replace", + patch, + name, + block_name, + number, + transformer_index, + ) + + def set_model_attn1_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn1_patch") + + def set_model_attn2_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn2_patch") + + def set_model_attn1_replace( + self, + patch: Any, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self.set_model_patch_replace( + patch, "attn1", block_name, number, transformer_index + ) + + def set_model_attn2_replace( + self, + patch: Any, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self.set_model_patch_replace( + patch, "attn2", block_name, number, transformer_index + ) + + def set_model_attn1_output_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn1_output_patch") + + def set_model_attn2_output_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn2_output_patch") + + def set_model_input_block_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "input_block_patch") + + def set_model_input_block_patch_after_skip(self, patch: Any) -> None: + self.set_model_patch(patch, "input_block_patch_after_skip") + + def set_model_output_block_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "output_block_patch") + + def set_model_emb_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "emb_patch") + + def set_model_forward_timestep_embed_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "forward_timestep_embed_patch") + + def set_model_double_block_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "double_block") + + def set_model_post_input_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "post_input") + + def set_model_rope_options( + self, + scale_x=1.0, + shift_x=0.0, + scale_y=1.0, + shift_y=0.0, + scale_t=1.0, + shift_t=0.0, + **kwargs: Any, + ) -> None: + options = { + "scale_x": scale_x, + "shift_x": shift_x, + "scale_y": scale_y, + "shift_y": shift_y, + "scale_t": scale_t, + "shift_t": shift_t, + } + options.update(kwargs) + self._call_rpc("set_model_rope_options", options) + + def set_model_compute_dtype(self, dtype: Any) -> None: + self._call_rpc("set_model_compute_dtype", dtype) + + def add_object_patch(self, name: str, obj: Any) -> None: + self._call_rpc("add_object_patch", name, obj) + + def add_weight_wrapper(self, name: str, function: Any) -> None: + self._call_rpc("add_weight_wrapper", name, function) + + def add_wrapper_with_key(self, wrapper_type: Any, key: str, fn: Any) -> None: + self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn) + + def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None: + self.add_wrapper_with_key(wrapper_type, None, wrapper) + + def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None: + self._call_rpc("remove_wrappers_with_key", wrapper_type, key) + + @property + def wrappers(self) -> Any: + return self._call_rpc("get_wrappers") + + def add_callback_with_key(self, call_type: str, key: str, callback: Any) -> None: + self._call_rpc("add_callback_with_key", call_type, key, callback) + + def add_callback(self, call_type: str, callback: Any) -> None: + self.add_callback_with_key(call_type, None, callback) + + def remove_callbacks_with_key(self, call_type: str, key: str) -> None: + self._call_rpc("remove_callbacks_with_key", call_type, key) + + @property + def callbacks(self) -> Any: + return self._call_rpc("get_callbacks") + + def set_attachments(self, key: str, attachment: Any) -> None: + self._call_rpc("set_attachments", key, attachment) + + def get_attachment(self, key: str) -> Any: + return self._call_rpc("get_attachment", key) + + def remove_attachments(self, key: str) -> None: + self._call_rpc("remove_attachments", key) + + def set_injections(self, key: str, injections: Any) -> None: + self._call_rpc("set_injections", key, injections) + + def get_injections(self, key: str) -> Any: + return self._call_rpc("get_injections", key) + + def remove_injections(self, key: str) -> None: + self._call_rpc("remove_injections", key) + + def set_additional_models(self, key: str, models: Any) -> None: + ids = [m._instance_id for m in models] + self._call_rpc("set_additional_models", key, ids) + + def remove_additional_models(self, key: str) -> None: + self._call_rpc("remove_additional_models", key) + + def get_nested_additional_models(self) -> Any: + return self._call_rpc("get_nested_additional_models") + + def get_additional_models(self) -> List[ModelPatcherProxy]: + ids = self._call_rpc("get_additional_models") + return [ + ModelPatcherProxy( + mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS + ) + for mid in ids + ] + + def model_patches_models(self) -> Any: + return self._call_rpc("model_patches_models") + + @property + def parent(self) -> Any: + return self._call_rpc("get_parent") + + +class _InnerModelProxy: + def __init__(self, parent: ModelPatcherProxy): + self._parent = parent + self._model_sampling = None + + def __getattr__(self, name: str) -> Any: + if name.startswith("_"): + raise AttributeError(name) + if name in ( + "model_config", + "latent_format", + "model_type", + "current_weight_patches_uuid", + ): + return self._parent._call_rpc("get_inner_model_attr", name) + if name == "load_device": + return self._parent._call_rpc("get_inner_model_attr", "load_device") + if name == "device": + return self._parent._call_rpc("get_inner_model_attr", "device") + if name == "current_patcher": + return ModelPatcherProxy( + self._parent._instance_id, + self._parent._registry, + manage_lifecycle=False, + ) + if name == "model_sampling": + if self._model_sampling is None: + self._model_sampling = self._parent._call_rpc( + "get_model_object", "model_sampling" + ) + return self._model_sampling + if name == "extra_conds_shapes": + return lambda *a, **k: self._parent._call_rpc( + "inner_model_extra_conds_shapes", a, k + ) + if name == "extra_conds": + return lambda *a, **k: self._parent._call_rpc( + "inner_model_extra_conds", a, k + ) + if name == "memory_required": + return lambda *a, **k: self._parent._call_rpc( + "inner_model_memory_required", a, k + ) + if name == "apply_model": + # Delegate to parent's method to get the CPU->CUDA optimization + return self._parent.apply_model + if name == "process_latent_in": + return lambda *a, **k: self._parent._call_rpc("process_latent_in", a, k) + if name == "process_latent_out": + return lambda *a, **k: self._parent._call_rpc("process_latent_out", a, k) + if name == "scale_latent_inpaint": + return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k) + if name == "diffusion_model": + return self._parent._call_rpc("get_inner_model_attr", "diffusion_model") + raise AttributeError(f"'{name}' not supported on isolated InnerModel") diff --git a/comfy/isolation/model_patcher_proxy_registry.py b/comfy/isolation/model_patcher_proxy_registry.py new file mode 100644 index 000000000..c696f6a0a --- /dev/null +++ b/comfy/isolation/model_patcher_proxy_registry.py @@ -0,0 +1,1230 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,unused-import +# RPC server for ModelPatcher isolation (child process) +from __future__ import annotations + +import asyncio +import gc +import logging +import threading +import time +from dataclasses import dataclass, field +from typing import Any, Optional, List + +try: + from comfy.model_patcher import AutoPatcherEjector +except ImportError: + + class AutoPatcherEjector: + def __init__(self, model, skip_and_inject_on_exit_only=False): + self.model = model + self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only + self.prev_skip_injection = False + self.was_injected = False + + def __enter__(self): + self.was_injected = False + self.prev_skip_injection = self.model.skip_injection + if self.skip_and_inject_on_exit_only: + self.model.skip_injection = True + if self.model.is_injected: + self.model.eject_model() + self.was_injected = True + + def __exit__(self, *args): + if self.skip_and_inject_on_exit_only: + self.model.skip_injection = self.prev_skip_injection + self.model.inject_model() + if self.was_injected and not self.model.skip_injection: + self.model.inject_model() + self.model.skip_injection = self.prev_skip_injection + + +from comfy.isolation.proxies.base import ( + BaseRegistry, + detach_if_grad, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class _OperationState: + lease: threading.Lock = field(default_factory=threading.Lock) + active_count: int = 0 + active_by_method: dict[str, int] = field(default_factory=dict) + total_operations: int = 0 + last_method: Optional[str] = None + last_started_ts: Optional[float] = None + last_ended_ts: Optional[float] = None + last_elapsed_ms: Optional[float] = None + last_error: Optional[str] = None + last_thread_id: Optional[int] = None + last_loop_id: Optional[int] = None + + +class ModelPatcherRegistry(BaseRegistry[Any]): + _type_prefix = "model" + + def __init__(self) -> None: + super().__init__() + self._pending_cleanup_ids: set[str] = set() + self._operation_states: dict[str, _OperationState] = {} + self._operation_state_cv = threading.Condition(self._lock) + + def _get_or_create_operation_state(self, instance_id: str) -> _OperationState: + state = self._operation_states.get(instance_id) + if state is None: + state = _OperationState() + self._operation_states[instance_id] = state + return state + + def _begin_operation(self, instance_id: str, method_name: str) -> tuple[float, float]: + start_epoch = time.time() + start_perf = time.perf_counter() + with self._operation_state_cv: + state = self._get_or_create_operation_state(instance_id) + state.active_count += 1 + state.active_by_method[method_name] = ( + state.active_by_method.get(method_name, 0) + 1 + ) + state.total_operations += 1 + state.last_method = method_name + state.last_started_ts = start_epoch + state.last_thread_id = threading.get_ident() + try: + state.last_loop_id = id(asyncio.get_running_loop()) + except RuntimeError: + state.last_loop_id = None + logger.debug( + "ISO:registry_op_start instance_id=%s method=%s start_ts=%.6f thread=%s loop=%s", + instance_id, + method_name, + start_epoch, + threading.get_ident(), + state.last_loop_id, + ) + return start_epoch, start_perf + + def _end_operation( + self, + instance_id: str, + method_name: str, + start_perf: float, + error: Optional[BaseException] = None, + ) -> None: + end_epoch = time.time() + elapsed_ms = (time.perf_counter() - start_perf) * 1000.0 + with self._operation_state_cv: + state = self._get_or_create_operation_state(instance_id) + state.active_count = max(0, state.active_count - 1) + if method_name in state.active_by_method: + remaining = state.active_by_method[method_name] - 1 + if remaining <= 0: + state.active_by_method.pop(method_name, None) + else: + state.active_by_method[method_name] = remaining + state.last_ended_ts = end_epoch + state.last_elapsed_ms = elapsed_ms + state.last_error = None if error is None else repr(error) + if state.active_count == 0: + self._operation_state_cv.notify_all() + logger.debug( + "ISO:registry_op_end instance_id=%s method=%s end_ts=%.6f elapsed_ms=%.3f error=%s", + instance_id, + method_name, + end_epoch, + elapsed_ms, + None if error is None else type(error).__name__, + ) + + def _run_operation_with_lease(self, instance_id: str, method_name: str, fn): + with self._operation_state_cv: + state = self._get_or_create_operation_state(instance_id) + lease = state.lease + with lease: + _, start_perf = self._begin_operation(instance_id, method_name) + try: + result = fn() + except Exception as exc: + self._end_operation(instance_id, method_name, start_perf, error=exc) + raise + self._end_operation(instance_id, method_name, start_perf) + return result + + def _snapshot_operation_state(self, instance_id: str) -> dict[str, Any]: + with self._operation_state_cv: + state = self._operation_states.get(instance_id) + if state is None: + return { + "instance_id": instance_id, + "active_count": 0, + "active_methods": [], + "total_operations": 0, + "last_method": None, + "last_started_ts": None, + "last_ended_ts": None, + "last_elapsed_ms": None, + "last_error": None, + "last_thread_id": None, + "last_loop_id": None, + } + return { + "instance_id": instance_id, + "active_count": state.active_count, + "active_methods": sorted(state.active_by_method.keys()), + "total_operations": state.total_operations, + "last_method": state.last_method, + "last_started_ts": state.last_started_ts, + "last_ended_ts": state.last_ended_ts, + "last_elapsed_ms": state.last_elapsed_ms, + "last_error": state.last_error, + "last_thread_id": state.last_thread_id, + "last_loop_id": state.last_loop_id, + } + + def unregister_sync(self, instance_id: str) -> None: + with self._operation_state_cv: + instance = self._registry.pop(instance_id, None) + if instance is not None: + self._id_map.pop(id(instance), None) + self._pending_cleanup_ids.discard(instance_id) + self._operation_states.pop(instance_id, None) + self._operation_state_cv.notify_all() + + async def get_operation_state(self, instance_id: str) -> dict[str, Any]: + return self._snapshot_operation_state(instance_id) + + async def get_all_operation_states(self) -> dict[str, dict[str, Any]]: + with self._operation_state_cv: + ids = sorted(self._operation_states.keys()) + return {instance_id: self._snapshot_operation_state(instance_id) for instance_id in ids} + + async def wait_for_idle(self, instance_id: str, timeout_ms: int = 0) -> bool: + timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0) + deadline = None if timeout_s is None else (time.monotonic() + timeout_s) + with self._operation_state_cv: + while True: + active = self._operation_states.get(instance_id) + if active is None or active.active_count == 0: + return True + if deadline is None: + self._operation_state_cv.wait() + continue + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + self._operation_state_cv.wait(timeout=remaining) + + async def wait_all_idle(self, timeout_ms: int = 0) -> bool: + timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0) + deadline = None if timeout_s is None else (time.monotonic() + timeout_s) + with self._operation_state_cv: + while True: + has_active = any( + state.active_count > 0 for state in self._operation_states.values() + ) + if not has_active: + return True + if deadline is None: + self._operation_state_cv.wait() + continue + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + self._operation_state_cv.wait(timeout=remaining) + + async def clone(self, instance_id: str) -> str: + instance = self._get_instance(instance_id) + new_model = instance.clone() + return self.register(new_model) + + async def is_clone(self, instance_id: str, other: Any) -> bool: + instance = self._get_instance(instance_id) + if hasattr(other, "model"): + return instance.is_clone(other) + return False + + async def get_model_object(self, instance_id: str, name: str) -> Any: + instance = self._get_instance(instance_id) + if name == "model": + return f"" + result = instance.get_model_object(name) + if name == "model_sampling": + from comfy.isolation.model_sampling_proxy import ( + ModelSamplingRegistry, + ModelSamplingProxy, + ) + + registry = ModelSamplingRegistry() + # Preserve identity when upstream already returned a proxy. Re-registering + # a proxy object creates proxy-of-proxy call chains. + if isinstance(result, ModelSamplingProxy): + sampling_id = result._instance_id + else: + sampling_id = registry.register(result) + return ModelSamplingProxy(sampling_id, registry) + + return detach_if_grad(result) + + async def get_model_options(self, instance_id: str) -> dict: + instance = self._get_instance(instance_id) + import copy + + opts = copy.deepcopy(instance.model_options) + return self._sanitize_rpc_result(opts) + + async def set_model_options(self, instance_id: str, options: dict) -> None: + self._get_instance(instance_id).model_options = options + + async def get_patcher_attr(self, instance_id: str, name: str) -> Any: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), name, None) + ) + + async def model_state_dict(self, instance_id: str, filter_prefix=None) -> Any: + instance = self._get_instance(instance_id) + sd_keys = instance.model.state_dict().keys() + return dict.fromkeys(sd_keys, None) + + def _sanitize_rpc_result(self, obj, seen=None): + if seen is None: + seen = set() + if obj is None: + return None + if isinstance(obj, (bool, int, float, str)): + if isinstance(obj, str) and len(obj) > 500000: + return f"" + return obj + obj_id = id(obj) + if obj_id in seen: + return None + seen.add(obj_id) + if isinstance(obj, (list, tuple)): + return [self._sanitize_rpc_result(x, seen) for x in obj] + if isinstance(obj, set): + return [self._sanitize_rpc_result(x, seen) for x in obj] + if isinstance(obj, dict): + new_dict = {} + for k, v in obj.items(): + if isinstance(k, tuple): + import json + + try: + key_str = "__pyisolate_key__" + json.dumps(list(k)) + new_dict[key_str] = self._sanitize_rpc_result(v, seen) + except Exception: + new_dict[str(k)] = self._sanitize_rpc_result(v, seen) + else: + new_dict[str(k)] = self._sanitize_rpc_result(v, seen) + return new_dict + if ( + hasattr(obj, "__dict__") + and not hasattr(obj, "__get__") + and not hasattr(obj, "__call__") + ): + return self._sanitize_rpc_result(obj.__dict__, seen) + if hasattr(obj, "items") and hasattr(obj, "get"): + return {str(k): self._sanitize_rpc_result(v, seen) for k, v in obj.items()} + return None + + async def get_load_device(self, instance_id: str) -> Any: + return self._get_instance(instance_id).load_device + + async def get_offload_device(self, instance_id: str) -> Any: + return self._get_instance(instance_id).offload_device + + async def current_loaded_device(self, instance_id: str) -> Any: + return self._get_instance(instance_id).current_loaded_device() + + async def get_size(self, instance_id: str) -> int: + return self._get_instance(instance_id).size + + async def model_size(self, instance_id: str) -> Any: + return self._get_instance(instance_id).model_size() + + async def loaded_size(self, instance_id: str) -> Any: + return self._get_instance(instance_id).loaded_size() + + async def get_ram_usage(self, instance_id: str) -> int: + return self._get_instance(instance_id).get_ram_usage() + + async def lowvram_patch_counter(self, instance_id: str) -> int: + return self._get_instance(instance_id).lowvram_patch_counter() + + async def memory_required(self, instance_id: str, input_shape: Any) -> Any: + return self._run_operation_with_lease( + instance_id, + "memory_required", + lambda: self._get_instance(instance_id).memory_required(input_shape), + ) + + async def is_dynamic(self, instance_id: str) -> bool: + instance = self._get_instance(instance_id) + if hasattr(instance, "is_dynamic"): + return bool(instance.is_dynamic()) + return False + + async def get_free_memory(self, instance_id: str, device: Any) -> Any: + instance = self._get_instance(instance_id) + if hasattr(instance, "get_free_memory"): + return instance.get_free_memory(device) + import comfy.model_management + + return comfy.model_management.get_free_memory(device) + + async def partially_unload_ram(self, instance_id: str, ram_to_unload: int) -> Any: + instance = self._get_instance(instance_id) + if hasattr(instance, "partially_unload_ram"): + return instance.partially_unload_ram(ram_to_unload) + return None + + async def model_dtype(self, instance_id: str) -> Any: + return self._run_operation_with_lease( + instance_id, + "model_dtype", + lambda: self._get_instance(instance_id).model_dtype(), + ) + + async def model_patches_to(self, instance_id: str, device: Any) -> Any: + return self._get_instance(instance_id).model_patches_to(device) + + async def partially_load( + self, + instance_id: str, + device: Any, + extra_memory: Any, + force_patch_weights: bool = False, + ) -> Any: + return self._run_operation_with_lease( + instance_id, + "partially_load", + lambda: self._get_instance(instance_id).partially_load( + device, extra_memory, force_patch_weights=force_patch_weights + ), + ) + + async def partially_unload( + self, + instance_id: str, + device_to: Any, + memory_to_free: int = 0, + force_patch_weights: bool = False, + ) -> int: + return self._run_operation_with_lease( + instance_id, + "partially_unload", + lambda: self._get_instance(instance_id).partially_unload( + device_to, memory_to_free, force_patch_weights + ), + ) + + async def load( + self, + instance_id: str, + device_to: Any = None, + lowvram_model_memory: int = 0, + force_patch_weights: bool = False, + full_load: bool = False, + ) -> None: + self._run_operation_with_lease( + instance_id, + "load", + lambda: self._get_instance(instance_id).load( + device_to, lowvram_model_memory, force_patch_weights, full_load + ), + ) + + async def patch_model( + self, + instance_id: str, + device_to: Any = None, + lowvram_model_memory: int = 0, + load_weights: bool = True, + force_patch_weights: bool = False, + ) -> None: + def _invoke() -> None: + try: + self._get_instance(instance_id).patch_model( + device_to, lowvram_model_memory, load_weights, force_patch_weights + ) + except AttributeError as e: + logger.error( + f"Isolation Error: Failed to patch model attribute: {e}. Skipping." + ) + return + + self._run_operation_with_lease(instance_id, "patch_model", _invoke) + + async def unpatch_model( + self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True + ) -> None: + self._run_operation_with_lease( + instance_id, + "unpatch_model", + lambda: self._get_instance(instance_id).unpatch_model( + device_to, unpatch_weights + ), + ) + + async def detach(self, instance_id: str, unpatch_all: bool = True) -> None: + self._get_instance(instance_id).detach(unpatch_all) + + async def prepare_state(self, instance_id: str, timestep: Any) -> Any: + instance = self._get_instance(instance_id) + cp = getattr(instance.model, "current_patcher", instance) + if cp is None: + cp = instance + return cp.prepare_state(timestep) + + async def pre_run(self, instance_id: str) -> None: + self._get_instance(instance_id).pre_run() + + async def cleanup(self, instance_id: str) -> None: + def _invoke() -> None: + try: + instance = self._get_instance(instance_id) + except Exception: + logger.debug( + "ModelPatcher cleanup requested for missing instance %s", + instance_id, + exc_info=True, + ) + return + + try: + instance.cleanup() + finally: + with self._lock: + self._pending_cleanup_ids.add(instance_id) + gc.collect() + + self._run_operation_with_lease(instance_id, "cleanup", _invoke) + + def sweep_pending_cleanup(self) -> int: + removed = 0 + with self._operation_state_cv: + pending_ids = list(self._pending_cleanup_ids) + self._pending_cleanup_ids.clear() + for instance_id in pending_ids: + instance = self._registry.pop(instance_id, None) + if instance is None: + continue + self._id_map.pop(id(instance), None) + self._operation_states.pop(instance_id, None) + removed += 1 + self._operation_state_cv.notify_all() + + gc.collect() + return removed + + def purge_all(self) -> int: + with self._operation_state_cv: + removed = len(self._registry) + self._registry.clear() + self._id_map.clear() + self._pending_cleanup_ids.clear() + self._operation_states.clear() + self._operation_state_cv.notify_all() + gc.collect() + return removed + + async def apply_hooks(self, instance_id: str, hooks: Any) -> Any: + instance = self._get_instance(instance_id) + cp = getattr(instance.model, "current_patcher", instance) + if cp is None: + cp = instance + return cp.apply_hooks(hooks=hooks) + + async def clean_hooks(self, instance_id: str) -> None: + self._get_instance(instance_id).clean_hooks() + + async def restore_hook_patches(self, instance_id: str) -> None: + self._get_instance(instance_id).restore_hook_patches() + + async def unpatch_hooks( + self, instance_id: str, whitelist_keys_set: Optional[set] = None + ) -> None: + self._get_instance(instance_id).unpatch_hooks(whitelist_keys_set) + + async def register_all_hook_patches( + self, + instance_id: str, + hooks: Any, + target_dict: Any, + model_options: Any, + registered: Any, + ) -> None: + from types import SimpleNamespace + import comfy.hooks + + instance = self._get_instance(instance_id) + if isinstance(hooks, SimpleNamespace) or hasattr(hooks, "__dict__"): + hook_data = hooks.__dict__ if hasattr(hooks, "__dict__") else hooks + new_hooks = comfy.hooks.HookGroup() + if hasattr(hook_data, "hooks"): + new_hooks.hooks = ( + hook_data["hooks"] + if isinstance(hook_data, dict) + else hook_data.hooks + ) + hooks = new_hooks + instance.register_all_hook_patches( + hooks, target_dict, model_options, registered + ) + + async def get_hook_mode(self, instance_id: str) -> Any: + return getattr(self._get_instance(instance_id), "hook_mode", None) + + async def set_hook_mode(self, instance_id: str, value: Any) -> None: + setattr(self._get_instance(instance_id), "hook_mode", value) + + async def inject_model(self, instance_id: str) -> None: + instance = self._get_instance(instance_id) + try: + instance.inject_model() + except AttributeError as e: + if "inject" in str(e): + logger.error( + "Isolation Error: Injector object lost method code during serialization. Cannot inject. Skipping." + ) + return + raise e + + async def eject_model(self, instance_id: str) -> None: + self._get_instance(instance_id).eject_model() + + async def get_is_injected(self, instance_id: str) -> bool: + return self._get_instance(instance_id).is_injected + + async def set_skip_injection(self, instance_id: str, value: bool) -> None: + self._get_instance(instance_id).skip_injection = value + + async def get_skip_injection(self, instance_id: str) -> bool: + return self._get_instance(instance_id).skip_injection + + async def set_model_sampler_cfg_function( + self, + instance_id: str, + sampler_cfg_function: Any, + disable_cfg1_optimization: bool = False, + ) -> None: + if not callable(sampler_cfg_function): + logger.error( + f"set_model_sampler_cfg_function: Expected callable, got {type(sampler_cfg_function)}. Skipping." + ) + return + self._get_instance(instance_id).set_model_sampler_cfg_function( + sampler_cfg_function, disable_cfg1_optimization + ) + + async def set_model_sampler_post_cfg_function( + self, + instance_id: str, + post_cfg_function: Any, + disable_cfg1_optimization: bool = False, + ) -> None: + self._get_instance(instance_id).set_model_sampler_post_cfg_function( + post_cfg_function, disable_cfg1_optimization + ) + + async def set_model_sampler_pre_cfg_function( + self, + instance_id: str, + pre_cfg_function: Any, + disable_cfg1_optimization: bool = False, + ) -> None: + self._get_instance(instance_id).set_model_sampler_pre_cfg_function( + pre_cfg_function, disable_cfg1_optimization + ) + + async def set_model_sampler_calc_cond_batch_function( + self, instance_id: str, fn: Any + ) -> None: + self._get_instance(instance_id).set_model_sampler_calc_cond_batch_function(fn) + + async def set_model_unet_function_wrapper( + self, instance_id: str, unet_wrapper_function: Any + ) -> None: + self._get_instance(instance_id).set_model_unet_function_wrapper( + unet_wrapper_function + ) + + async def set_model_denoise_mask_function( + self, instance_id: str, denoise_mask_function: Any + ) -> None: + self._get_instance(instance_id).set_model_denoise_mask_function( + denoise_mask_function + ) + + async def set_model_patch(self, instance_id: str, patch: Any, name: str) -> None: + self._get_instance(instance_id).set_model_patch(patch, name) + + async def set_model_patch_replace( + self, + instance_id: str, + patch: Any, + name: str, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self._get_instance(instance_id).set_model_patch_replace( + patch, name, block_name, number, transformer_index + ) + + async def set_model_input_block_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_input_block_patch(patch) + + async def set_model_input_block_patch_after_skip( + self, instance_id: str, patch: Any + ) -> None: + self._get_instance(instance_id).set_model_input_block_patch_after_skip(patch) + + async def set_model_output_block_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_output_block_patch(patch) + + async def set_model_emb_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_emb_patch(patch) + + async def set_model_forward_timestep_embed_patch( + self, instance_id: str, patch: Any + ) -> None: + self._get_instance(instance_id).set_model_forward_timestep_embed_patch(patch) + + async def set_model_double_block_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_double_block_patch(patch) + + async def set_model_post_input_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_post_input_patch(patch) + + async def set_model_rope_options(self, instance_id: str, options: dict) -> None: + self._get_instance(instance_id).set_model_rope_options(**options) + + async def set_model_compute_dtype(self, instance_id: str, dtype: Any) -> None: + self._get_instance(instance_id).set_model_compute_dtype(dtype) + + async def clone_has_same_weights_by_id( + self, instance_id: str, other_id: str + ) -> bool: + instance = self._get_instance(instance_id) + other = self._get_instance(other_id) + if not other: + return False + return instance.clone_has_same_weights(other) + + async def load_list_internal(self, instance_id: str, *args, **kwargs) -> Any: + return self._get_instance(instance_id)._load_list(*args, **kwargs) + + async def is_clone_by_id(self, instance_id: str, other_id: str) -> bool: + instance = self._get_instance(instance_id) + other = self._get_instance(other_id) + if hasattr(instance, "is_clone"): + return instance.is_clone(other) + return False + + async def add_object_patch(self, instance_id: str, name: str, obj: Any) -> None: + self._get_instance(instance_id).add_object_patch(name, obj) + + async def add_weight_wrapper( + self, instance_id: str, name: str, function: Any + ) -> None: + self._get_instance(instance_id).add_weight_wrapper(name, function) + + async def add_wrapper_with_key( + self, instance_id: str, wrapper_type: Any, key: str, fn: Any + ) -> None: + self._get_instance(instance_id).add_wrapper_with_key(wrapper_type, key, fn) + + async def remove_wrappers_with_key( + self, instance_id: str, wrapper_type: str, key: str + ) -> None: + self._get_instance(instance_id).remove_wrappers_with_key(wrapper_type, key) + + async def get_wrappers( + self, instance_id: str, wrapper_type: str = None, key: str = None + ) -> Any: + if wrapper_type is None and key is None: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "wrappers", {}) + ) + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_wrappers(wrapper_type, key) + ) + + async def get_all_wrappers(self, instance_id: str, wrapper_type: str = None) -> Any: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "get_all_wrappers", lambda x: [])( + wrapper_type + ) + ) + + async def add_callback_with_key( + self, instance_id: str, call_type: str, key: str, callback: Any + ) -> None: + self._get_instance(instance_id).add_callback_with_key(call_type, key, callback) + + async def remove_callbacks_with_key( + self, instance_id: str, call_type: str, key: str + ) -> None: + self._get_instance(instance_id).remove_callbacks_with_key(call_type, key) + + async def get_callbacks( + self, instance_id: str, call_type: str = None, key: str = None + ) -> Any: + if call_type is None and key is None: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "callbacks", {}) + ) + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_callbacks(call_type, key) + ) + + async def get_all_callbacks(self, instance_id: str, call_type: str = None) -> Any: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "get_all_callbacks", lambda x: [])( + call_type + ) + ) + + async def set_attachments( + self, instance_id: str, key: str, attachment: Any + ) -> None: + self._get_instance(instance_id).set_attachments(key, attachment) + + async def get_attachment(self, instance_id: str, key: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_attachment(key) + ) + + async def remove_attachments(self, instance_id: str, key: str) -> None: + self._get_instance(instance_id).remove_attachments(key) + + async def set_injections(self, instance_id: str, key: str, injections: Any) -> None: + self._get_instance(instance_id).set_injections(key, injections) + + async def get_injections(self, instance_id: str, key: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_injections(key) + ) + + async def remove_injections(self, instance_id: str, key: str) -> None: + self._get_instance(instance_id).remove_injections(key) + + async def set_additional_models( + self, instance_id: str, key: str, models: Any + ) -> None: + self._get_instance(instance_id).set_additional_models(key, models) + + async def remove_additional_models(self, instance_id: str, key: str) -> None: + self._get_instance(instance_id).remove_additional_models(key) + + async def get_nested_additional_models(self, instance_id: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_nested_additional_models() + ) + + async def get_additional_models(self, instance_id: str) -> List[str]: + models = self._get_instance(instance_id).get_additional_models() + return [self.register(m) for m in models] + + async def get_additional_models_with_key(self, instance_id: str, key: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_additional_models_with_key(key) + ) + + async def model_patches_models(self, instance_id: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).model_patches_models() + ) + + async def get_patches(self, instance_id: str) -> Any: + return self._sanitize_rpc_result(self._get_instance(instance_id).patches.copy()) + + async def get_object_patches(self, instance_id: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).object_patches.copy() + ) + + async def add_patches( + self, + instance_id: str, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> Any: + return self._get_instance(instance_id).add_patches( + patches, strength_patch, strength_model + ) + + async def get_key_patches( + self, instance_id: str, filter_prefix: Optional[str] = None + ) -> Any: + res = self._get_instance(instance_id).get_key_patches() + if filter_prefix: + res = {k: v for k, v in res.items() if k.startswith(filter_prefix)} + safe_res = {} + for k, v in res.items(): + safe_res[k] = [ + f"" + if hasattr(t, "shape") + else str(t) + for t in v + ] + return safe_res + + async def add_hook_patches( + self, + instance_id: str, + hook: Any, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> None: + if hasattr(hook, "hook_ref") and isinstance(hook.hook_ref, dict): + try: + hook.hook_ref = tuple(sorted(hook.hook_ref.items())) + except Exception: + hook.hook_ref = None + self._get_instance(instance_id).add_hook_patches( + hook, patches, strength_patch, strength_model + ) + + async def get_combined_hook_patches(self, instance_id: str, hooks: Any) -> Any: + if hooks is not None and hasattr(hooks, "hooks"): + for hook in getattr(hooks, "hooks", []): + hook_ref = getattr(hook, "hook_ref", None) + if isinstance(hook_ref, dict): + try: + hook.hook_ref = tuple(sorted(hook_ref.items())) + except Exception: + hook.hook_ref = None + res = self._get_instance(instance_id).get_combined_hook_patches(hooks) + return self._sanitize_rpc_result(res) + + async def clear_cached_hook_weights(self, instance_id: str) -> None: + self._get_instance(instance_id).clear_cached_hook_weights() + + async def prepare_hook_patches_current_keyframe( + self, instance_id: str, t: Any, hook_group: Any, model_options: Any + ) -> None: + self._get_instance(instance_id).prepare_hook_patches_current_keyframe( + t, hook_group, model_options + ) + + async def get_parent(self, instance_id: str) -> Any: + return getattr(self._get_instance(instance_id), "parent", None) + + async def patch_weight_to_device( + self, + instance_id: str, + key: str, + device_to: Any = None, + inplace_update: bool = False, + ) -> None: + self._get_instance(instance_id).patch_weight_to_device( + key, device_to, inplace_update + ) + + async def pin_weight_to_device(self, instance_id: str, key: str) -> None: + instance = self._get_instance(instance_id) + if hasattr(instance, "pinned") and isinstance(instance.pinned, list): + instance.pinned = set(instance.pinned) + instance.pin_weight_to_device(key) + + async def unpin_weight(self, instance_id: str, key: str) -> None: + instance = self._get_instance(instance_id) + if hasattr(instance, "pinned") and isinstance(instance.pinned, list): + instance.pinned = set(instance.pinned) + instance.unpin_weight(key) + + async def unpin_all_weights(self, instance_id: str) -> None: + instance = self._get_instance(instance_id) + if hasattr(instance, "pinned") and isinstance(instance.pinned, list): + instance.pinned = set(instance.pinned) + instance.unpin_all_weights() + + async def calculate_weight( + self, + instance_id: str, + patches: Any, + weight: Any, + key: str, + intermediate_dtype: Any = float, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).calculate_weight( + patches, weight, key, intermediate_dtype + ) + ) + + async def get_inner_model_attr(self, instance_id: str, name: str) -> Any: + try: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id).model, name) + ) + except AttributeError: + return None + + async def inner_model_memory_required( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + return self._run_operation_with_lease( + instance_id, + "inner_model_memory_required", + lambda: self._get_instance(instance_id).model.memory_required( + *args, **kwargs + ), + ) + + async def inner_model_extra_conds_shapes( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + return self._run_operation_with_lease( + instance_id, + "inner_model_extra_conds_shapes", + lambda: self._get_instance(instance_id).model.extra_conds_shapes( + *args, **kwargs + ), + ) + + async def inner_model_extra_conds( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + def _invoke() -> Any: + result = self._get_instance(instance_id).model.extra_conds(*args, **kwargs) + try: + import torch + import comfy.conds + except Exception: + return result + + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + if isinstance(obj, comfy.conds.CONDRegular): + return type(obj)(_to_cpu(obj.cond)) + return obj + + return _to_cpu(result) + + return self._run_operation_with_lease(instance_id, "inner_model_extra_conds", _invoke) + + async def inner_model_state_dict( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + sd = self._get_instance(instance_id).model.state_dict(*args, **kwargs) + return { + k: {"numel": v.numel(), "element_size": v.element_size()} + for k, v in sd.items() + } + + async def inner_model_apply_model( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + def _invoke() -> Any: + import torch + + instance = self._get_instance(instance_id) + target = getattr(instance, "load_device", None) + if target is None and args and hasattr(args[0], "device"): + target = args[0].device + elif target is None: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + + def _move(obj): + if target is None: + return obj + if isinstance(obj, (tuple, list)): + return type(obj)(_move(o) for o in obj) + if hasattr(obj, "to"): + return obj.to(target) + return obj + + moved_args = tuple(_move(a) for a in args) + moved_kwargs = {k: _move(v) for k, v in kwargs.items()} + result = instance.model.apply_model(*moved_args, **moved_kwargs) + moved_result = detach_if_grad(_move(result)) + + # DynamicVRAM + isolation: returning CUDA tensors across RPC can stall + # at the transport boundary. Marshal dynamic-path results as CPU and let + # the proxy restore device placement in the child process. + is_dynamic_fn = getattr(instance, "is_dynamic", None) + if callable(is_dynamic_fn) and is_dynamic_fn(): + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + return obj + + return _to_cpu(moved_result) + return moved_result + + return self._run_operation_with_lease(instance_id, "inner_model_apply_model", _invoke) + + async def process_latent_in( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + import torch + + def _invoke() -> Any: + instance = self._get_instance(instance_id) + result = detach_if_grad(instance.model.process_latent_in(*args, **kwargs)) + + # DynamicVRAM + isolation: returning CUDA tensors across RPC can stall + # at the transport boundary. Marshal dynamic-path results as CPU and let + # the proxy restore placement when needed. + is_dynamic_fn = getattr(instance, "is_dynamic", None) + if callable(is_dynamic_fn) and is_dynamic_fn(): + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + return obj + + return _to_cpu(result) + return result + + return self._run_operation_with_lease(instance_id, "process_latent_in", _invoke) + + async def process_latent_out( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + import torch + + def _invoke() -> Any: + instance = self._get_instance(instance_id) + result = instance.model.process_latent_out(*args, **kwargs) + moved_result = None + try: + target = None + if args and hasattr(args[0], "device"): + target = args[0].device + elif kwargs: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + if target is not None and hasattr(result, "to"): + moved_result = detach_if_grad(result.to(target)) + except Exception: + logger.debug( + "process_latent_out: failed to move result to target device", + exc_info=True, + ) + if moved_result is None: + moved_result = detach_if_grad(result) + + is_dynamic_fn = getattr(instance, "is_dynamic", None) + if callable(is_dynamic_fn) and is_dynamic_fn(): + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + return obj + + return _to_cpu(moved_result) + return moved_result + + return self._run_operation_with_lease(instance_id, "process_latent_out", _invoke) + + async def scale_latent_inpaint( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + import torch + + def _invoke() -> Any: + instance = self._get_instance(instance_id) + result = instance.model.scale_latent_inpaint(*args, **kwargs) + moved_result = None + try: + target = None + if args and hasattr(args[0], "device"): + target = args[0].device + elif kwargs: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + if target is not None and hasattr(result, "to"): + moved_result = detach_if_grad(result.to(target)) + except Exception: + logger.debug( + "scale_latent_inpaint: failed to move result to target device", + exc_info=True, + ) + if moved_result is None: + moved_result = detach_if_grad(result) + + is_dynamic_fn = getattr(instance, "is_dynamic", None) + if callable(is_dynamic_fn) and is_dynamic_fn(): + def _to_cpu(obj: Any) -> Any: + if torch.is_tensor(obj): + return obj.detach().cpu() if obj.device.type != "cpu" else obj + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cpu(v) for v in obj) + return obj + + return _to_cpu(moved_result) + return moved_result + + return self._run_operation_with_lease( + instance_id, "scale_latent_inpaint", _invoke + ) + + async def load_lora( + self, + instance_id: str, + lora_path: str, + strength_model: float, + clip_id: Optional[str] = None, + strength_clip: float = 1.0, + ) -> dict: + import comfy.utils + import comfy.sd + import folder_paths + from comfy.isolation.clip_proxy import CLIPRegistry + + model = self._get_instance(instance_id) + clip = None + if clip_id: + clip = CLIPRegistry()._get_instance(clip_id) + lora_full_path = folder_paths.get_full_path("loras", lora_path) + if lora_full_path is None: + raise ValueError(f"LoRA file not found: {lora_path}") + lora = comfy.utils.load_torch_file(lora_full_path) + new_model, new_clip = comfy.sd.load_lora_for_models( + model, clip, lora, strength_model, strength_clip + ) + new_model_id = self.register(new_model) if new_model else None + new_clip_id = ( + CLIPRegistry().register(new_clip) if (new_clip and clip_id) else None + ) + return {"model_id": new_model_id, "clip_id": new_clip_id} diff --git a/comfy/isolation/model_patcher_proxy_utils.py b/comfy/isolation/model_patcher_proxy_utils.py new file mode 100644 index 000000000..038687f01 --- /dev/null +++ b/comfy/isolation/model_patcher_proxy_utils.py @@ -0,0 +1,156 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access +# Isolation utilities and serializers for ModelPatcherProxy +from __future__ import annotations + +import logging +import os +from typing import Any + +from comfy.cli_args import args + +logger = logging.getLogger(__name__) + + +def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any: + from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + isolation_active = args.use_process_isolation or is_child + + if not isolation_active: + return model_patcher + if is_child: + return model_patcher + if isinstance(model_patcher, ModelPatcherProxy): + return model_patcher + + registry = ModelPatcherRegistry() + model_id = registry.register(model_patcher) + logger.debug(f"Isolated ModelPatcher: {model_id}") + return ModelPatcherProxy(model_id, registry, manage_lifecycle=True) + + +def register_hooks_serializers(registry=None): + from pyisolate._internal.serialization_registry import SerializerRegistry + import comfy.hooks + + if registry is None: + registry = SerializerRegistry.get_instance() + + def serialize_enum(obj): + return {"__enum__": f"{type(obj).__name__}.{obj.name}"} + + def deserialize_enum(data): + cls_name, val_name = data["__enum__"].split(".") + cls = getattr(comfy.hooks, cls_name) + return cls[val_name] + + registry.register("EnumHookType", serialize_enum, deserialize_enum) + registry.register("EnumHookScope", serialize_enum, deserialize_enum) + registry.register("EnumHookMode", serialize_enum, deserialize_enum) + registry.register("EnumWeightTarget", serialize_enum, deserialize_enum) + + def serialize_hook_group(obj): + return {"__type__": "HookGroup", "hooks": obj.hooks} + + def deserialize_hook_group(data): + hg = comfy.hooks.HookGroup() + for h in data["hooks"]: + hg.add(h) + return hg + + registry.register("HookGroup", serialize_hook_group, deserialize_hook_group) + + def serialize_dict_state(obj): + d = obj.__dict__.copy() + d["__type__"] = type(obj).__name__ + if "custom_should_register" in d: + del d["custom_should_register"] + return d + + def deserialize_dict_state_generic(cls): + def _deserialize(data): + h = cls() + h.__dict__.update(data) + return h + + return _deserialize + + def deserialize_hook_keyframe(data): + h = comfy.hooks.HookKeyframe(strength=data.get("strength", 1.0)) + h.__dict__.update(data) + return h + + registry.register("HookKeyframe", serialize_dict_state, deserialize_hook_keyframe) + + def deserialize_hook_keyframe_group(data): + h = comfy.hooks.HookKeyframeGroup() + h.__dict__.update(data) + return h + + registry.register( + "HookKeyframeGroup", serialize_dict_state, deserialize_hook_keyframe_group + ) + + def deserialize_hook(data): + h = comfy.hooks.Hook() + h.__dict__.update(data) + return h + + registry.register("Hook", serialize_dict_state, deserialize_hook) + + def deserialize_weight_hook(data): + h = comfy.hooks.WeightHook() + h.__dict__.update(data) + return h + + registry.register("WeightHook", serialize_dict_state, deserialize_weight_hook) + + def serialize_set(obj): + return {"__set__": list(obj)} + + def deserialize_set(data): + return set(data["__set__"]) + + registry.register("set", serialize_set, deserialize_set) + + try: + from comfy.weight_adapter.lora import LoRAAdapter + + def serialize_lora(obj): + return {"weights": {}, "loaded_keys": list(obj.loaded_keys)} + + def deserialize_lora(data): + return LoRAAdapter(set(data["loaded_keys"]), data["weights"]) + + registry.register("LoRAAdapter", serialize_lora, deserialize_lora) + except Exception: + pass + + try: + from comfy.hooks import _HookRef + import uuid + + def serialize_hook_ref(obj): + return { + "__hook_ref__": True, + "id": getattr(obj, "_pyisolate_id", str(uuid.uuid4())), + } + + def deserialize_hook_ref(data): + h = _HookRef() + h._pyisolate_id = data.get("id", str(uuid.uuid4())) + return h + + registry.register("_HookRef", serialize_hook_ref, deserialize_hook_ref) + except ImportError: + pass + except Exception as e: + logger.warning(f"Failed to register _HookRef: {e}") + + +try: + register_hooks_serializers() +except Exception as e: + logger.error(f"Failed to initialize hook serializers: {e}") diff --git a/comfy/isolation/model_sampling_proxy.py b/comfy/isolation/model_sampling_proxy.py new file mode 100644 index 000000000..8fbfc5b93 --- /dev/null +++ b/comfy/isolation/model_sampling_proxy.py @@ -0,0 +1,360 @@ +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +import asyncio +import logging +import os +import threading +import time +from typing import Any + +from comfy.isolation.proxies.base import ( + BaseProxy, + BaseRegistry, + detach_if_grad, + get_thread_loop, + run_coro_in_new_loop, +) + +logger = logging.getLogger(__name__) + + +def _describe_value(obj: Any) -> str: + try: + import torch + except Exception: + torch = None + try: + if torch is not None and isinstance(obj, torch.Tensor): + return ( + "Tensor(shape=%s,dtype=%s,device=%s,id=%s)" + % (tuple(obj.shape), obj.dtype, obj.device, id(obj)) + ) + except Exception: + pass + return "%s(id=%s)" % (type(obj).__name__, id(obj)) + + +def _prefer_device(*tensors: Any) -> Any: + try: + import torch + except Exception: + return None + for t in tensors: + if isinstance(t, torch.Tensor) and t.is_cuda: + return t.device + for t in tensors: + if isinstance(t, torch.Tensor): + return t.device + return None + + +def _to_device(obj: Any, device: Any) -> Any: + try: + import torch + except Exception: + return obj + if device is None: + return obj + if isinstance(obj, torch.Tensor): + if obj.device != device: + return obj.to(device) + return obj + if isinstance(obj, (list, tuple)): + converted = [_to_device(x, device) for x in obj] + return type(obj)(converted) if isinstance(obj, tuple) else converted + if isinstance(obj, dict): + return {k: _to_device(v, device) for k, v in obj.items()} + return obj + + +def _to_cpu_for_rpc(obj: Any) -> Any: + try: + import torch + except Exception: + return obj + if isinstance(obj, torch.Tensor): + t = obj.detach() if obj.requires_grad else obj + if t.is_cuda: + return t.to("cpu") + return t + if isinstance(obj, (list, tuple)): + converted = [_to_cpu_for_rpc(x) for x in obj] + return type(obj)(converted) if isinstance(obj, tuple) else converted + if isinstance(obj, dict): + return {k: _to_cpu_for_rpc(v) for k, v in obj.items()} + return obj + + +class ModelSamplingRegistry(BaseRegistry[Any]): + _type_prefix = "modelsampling" + + async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.calculate_input(sigma, noise)) + + async def calculate_denoised( + self, instance_id: str, sigma: Any, model_output: Any, model_input: Any + ) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad( + sampling.calculate_denoised(sigma, model_output, model_input) + ) + + async def noise_scaling( + self, + instance_id: str, + sigma: Any, + noise: Any, + latent_image: Any, + max_denoise: bool = False, + ) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad( + sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise) + ) + + async def inverse_noise_scaling( + self, instance_id: str, sigma: Any, latent: Any + ) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent)) + + async def timestep(self, instance_id: str, sigma: Any) -> Any: + sampling = self._get_instance(instance_id) + return sampling.timestep(sigma) + + async def sigma(self, instance_id: str, timestep: Any) -> Any: + sampling = self._get_instance(instance_id) + return sampling.sigma(timestep) + + async def percent_to_sigma(self, instance_id: str, percent: float) -> Any: + sampling = self._get_instance(instance_id) + return sampling.percent_to_sigma(percent) + + async def get_sigma_min(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigma_min) + + async def get_sigma_max(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigma_max) + + async def get_sigma_data(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigma_data) + + async def get_sigmas(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigmas) + + async def set_sigmas(self, instance_id: str, sigmas: Any) -> None: + sampling = self._get_instance(instance_id) + sampling.set_sigmas(sigmas) + + +class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]): + _registry_class = ModelSamplingRegistry + __module__ = "comfy.isolation.model_sampling_proxy" + + def _get_rpc(self) -> Any: + if self._rpc_caller is None: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc is not None: + self._rpc_caller = rpc.create_caller( + ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id() + ) + else: + registry = ModelSamplingRegistry() + + class _LocalCaller: + def calculate_input( + self, instance_id: str, sigma: Any, noise: Any + ) -> Any: + return registry.calculate_input(instance_id, sigma, noise) + + def calculate_denoised( + self, + instance_id: str, + sigma: Any, + model_output: Any, + model_input: Any, + ) -> Any: + return registry.calculate_denoised( + instance_id, sigma, model_output, model_input + ) + + def noise_scaling( + self, + instance_id: str, + sigma: Any, + noise: Any, + latent_image: Any, + max_denoise: bool = False, + ) -> Any: + return registry.noise_scaling( + instance_id, sigma, noise, latent_image, max_denoise + ) + + def inverse_noise_scaling( + self, instance_id: str, sigma: Any, latent: Any + ) -> Any: + return registry.inverse_noise_scaling( + instance_id, sigma, latent + ) + + def timestep(self, instance_id: str, sigma: Any) -> Any: + return registry.timestep(instance_id, sigma) + + def sigma(self, instance_id: str, timestep: Any) -> Any: + return registry.sigma(instance_id, timestep) + + def percent_to_sigma(self, instance_id: str, percent: float) -> Any: + return registry.percent_to_sigma(instance_id, percent) + + def get_sigma_min(self, instance_id: str) -> Any: + return registry.get_sigma_min(instance_id) + + def get_sigma_max(self, instance_id: str) -> Any: + return registry.get_sigma_max(instance_id) + + def get_sigma_data(self, instance_id: str) -> Any: + return registry.get_sigma_data(instance_id) + + def get_sigmas(self, instance_id: str) -> Any: + return registry.get_sigmas(instance_id) + + def set_sigmas(self, instance_id: str, sigmas: Any) -> None: + return registry.set_sigmas(instance_id, sigmas) + + self._rpc_caller = _LocalCaller() + return self._rpc_caller + + def _call(self, method_name: str, *args: Any) -> Any: + rpc = self._get_rpc() + method = getattr(rpc, method_name) + result = method(self._instance_id, *args) + timeout_ms = self._rpc_timeout_ms() + start_epoch = time.time() + start_perf = time.perf_counter() + thread_id = threading.get_ident() + call_id = "%s:%s:%s:%.6f" % ( + self._instance_id, + method_name, + thread_id, + start_perf, + ) + logger.debug( + "ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s", + method_name, + self._instance_id, + call_id, + start_epoch, + thread_id, + timeout_ms, + ) + if asyncio.iscoroutine(result): + result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0) + try: + asyncio.get_running_loop() + out = run_coro_in_new_loop(result) + except RuntimeError: + loop = get_thread_loop() + out = loop.run_until_complete(result) + else: + out = result + logger.debug( + "ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s", + method_name, + self._instance_id, + call_id, + _describe_value(out), + ) + elapsed_ms = (time.perf_counter() - start_perf) * 1000.0 + logger.debug( + "ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s", + method_name, + self._instance_id, + call_id, + elapsed_ms, + thread_id, + ) + logger.debug( + "ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s", + method_name, + self._instance_id, + call_id, + ) + return out + + @staticmethod + def _rpc_timeout_ms() -> int: + raw = os.environ.get( + "COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS", + os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"), + ) + try: + timeout_ms = int(raw) + except ValueError: + timeout_ms = 30000 + return max(1, timeout_ms) + + @property + def sigma_min(self) -> Any: + return self._call("get_sigma_min") + + @property + def sigma_max(self) -> Any: + return self._call("get_sigma_max") + + @property + def sigma_data(self) -> Any: + return self._call("get_sigma_data") + + @property + def sigmas(self) -> Any: + return self._call("get_sigmas") + + def calculate_input(self, sigma: Any, noise: Any) -> Any: + return self._call("calculate_input", sigma, noise) + + def calculate_denoised( + self, sigma: Any, model_output: Any, model_input: Any + ) -> Any: + return self._call("calculate_denoised", sigma, model_output, model_input) + + def noise_scaling( + self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False + ) -> Any: + preferred_device = _prefer_device(noise, latent_image) + out = self._call( + "noise_scaling", + _to_cpu_for_rpc(sigma), + _to_cpu_for_rpc(noise), + _to_cpu_for_rpc(latent_image), + max_denoise, + ) + return _to_device(out, preferred_device) + + def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any: + preferred_device = _prefer_device(latent) + out = self._call( + "inverse_noise_scaling", + _to_cpu_for_rpc(sigma), + _to_cpu_for_rpc(latent), + ) + return _to_device(out, preferred_device) + + def timestep(self, sigma: Any) -> Any: + return self._call("timestep", sigma) + + def sigma(self, timestep: Any) -> Any: + return self._call("sigma", timestep) + + def percent_to_sigma(self, percent: float) -> Any: + return self._call("percent_to_sigma", percent) + + def set_sigmas(self, sigmas: Any) -> None: + return self._call("set_sigmas", sigmas) diff --git a/comfy/isolation/proxies/__init__.py b/comfy/isolation/proxies/__init__.py new file mode 100644 index 000000000..30d0089ad --- /dev/null +++ b/comfy/isolation/proxies/__init__.py @@ -0,0 +1,17 @@ +from .base import ( + IS_CHILD_PROCESS, + BaseProxy, + BaseRegistry, + detach_if_grad, + get_thread_loop, + run_coro_in_new_loop, +) + +__all__ = [ + "IS_CHILD_PROCESS", + "BaseRegistry", + "BaseProxy", + "get_thread_loop", + "run_coro_in_new_loop", + "detach_if_grad", +] diff --git a/comfy/isolation/proxies/base.py b/comfy/isolation/proxies/base.py new file mode 100644 index 000000000..71cc1943c --- /dev/null +++ b/comfy/isolation/proxies/base.py @@ -0,0 +1,283 @@ +# pylint: disable=global-statement,import-outside-toplevel,protected-access +from __future__ import annotations + +import asyncio +import concurrent.futures +import logging +import os +import threading +import time +import weakref +from typing import Any, Callable, Dict, Generic, Optional, TypeVar + +try: + from pyisolate import ProxiedSingleton +except ImportError: + + class ProxiedSingleton: # type: ignore[no-redef] + pass + + +logger = logging.getLogger(__name__) + +IS_CHILD_PROCESS = os.environ.get("PYISOLATE_CHILD") == "1" +_thread_local = threading.local() +T = TypeVar("T") + + +def get_thread_loop() -> asyncio.AbstractEventLoop: + loop = getattr(_thread_local, "loop", None) + if loop is None or loop.is_closed(): + loop = asyncio.new_event_loop() + _thread_local.loop = loop + return loop + + +def run_coro_in_new_loop(coro: Any) -> Any: + result_box: Dict[str, Any] = {} + exc_box: Dict[str, BaseException] = {} + + def runner() -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result_box["value"] = loop.run_until_complete(coro) + except Exception as exc: # noqa: BLE001 + exc_box["exc"] = exc + finally: + loop.close() + + t = threading.Thread(target=runner, daemon=True) + t.start() + t.join() + if "exc" in exc_box: + raise exc_box["exc"] + return result_box.get("value") + + +def detach_if_grad(obj: Any) -> Any: + try: + import torch + except Exception: + return obj + + if isinstance(obj, torch.Tensor): + return obj.detach() if obj.requires_grad else obj + if isinstance(obj, (list, tuple)): + return type(obj)(detach_if_grad(x) for x in obj) + if isinstance(obj, dict): + return {k: detach_if_grad(v) for k, v in obj.items()} + return obj + + +class BaseRegistry(ProxiedSingleton, Generic[T]): + _type_prefix: str = "base" + + def __init__(self) -> None: + if hasattr(ProxiedSingleton, "__init__") and ProxiedSingleton is not object: + super().__init__() + self._registry: Dict[str, T] = {} + self._id_map: Dict[int, str] = {} + self._counter = 0 + self._lock = threading.Lock() + + def register(self, instance: T) -> str: + with self._lock: + obj_id = id(instance) + if obj_id in self._id_map: + return self._id_map[obj_id] + instance_id = f"{self._type_prefix}_{self._counter}" + self._counter += 1 + self._registry[instance_id] = instance + self._id_map[obj_id] = instance_id + return instance_id + + def unregister_sync(self, instance_id: str) -> None: + with self._lock: + instance = self._registry.pop(instance_id, None) + if instance: + self._id_map.pop(id(instance), None) + + def _get_instance(self, instance_id: str) -> T: + if IS_CHILD_PROCESS: + raise RuntimeError( + f"[{self.__class__.__name__}] _get_instance called in child" + ) + with self._lock: + instance = self._registry.get(instance_id) + if instance is None: + raise ValueError(f"{instance_id} not found") + return instance + + +_GLOBAL_LOOP: Optional[asyncio.AbstractEventLoop] = None + + +def set_global_loop(loop: asyncio.AbstractEventLoop) -> None: + global _GLOBAL_LOOP + _GLOBAL_LOOP = loop + + +class BaseProxy(Generic[T]): + _registry_class: type = BaseRegistry # type: ignore[type-arg] + __module__: str = "comfy.isolation.proxies.base" + _TIMEOUT_RPC_METHODS = frozenset( + { + "partially_load", + "partially_unload", + "load", + "patch_model", + "unpatch_model", + "inner_model_apply_model", + "memory_required", + "model_dtype", + "inner_model_memory_required", + "inner_model_extra_conds_shapes", + "inner_model_extra_conds", + "process_latent_in", + "process_latent_out", + "scale_latent_inpaint", + } + ) + + def __init__( + self, + instance_id: str, + registry: Optional[Any] = None, + manage_lifecycle: bool = False, + ) -> None: + self._instance_id = instance_id + self._rpc_caller: Optional[Any] = None + self._registry = registry if registry is not None else self._registry_class() + self._manage_lifecycle = manage_lifecycle + self._cleaned_up = False + if manage_lifecycle and not IS_CHILD_PROCESS: + self._finalizer = weakref.finalize( + self, self._registry.unregister_sync, instance_id + ) + + def _get_rpc(self) -> Any: + if self._rpc_caller is None: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc is None: + raise RuntimeError(f"[{self.__class__.__name__}] No RPC in child") + self._rpc_caller = rpc.create_caller( + self._registry_class, self._registry_class.get_remote_id() + ) + return self._rpc_caller + + def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]: + if method_name not in self._TIMEOUT_RPC_METHODS: + return None + try: + timeout_ms = int( + os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000") + ) + except ValueError: + timeout_ms = 120000 + return max(1, timeout_ms) + + def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any: + rpc = self._get_rpc() + method = getattr(rpc, method_name) + timeout_ms = self._rpc_timeout_ms_for_method(method_name) + coro = method(self._instance_id, *args, **kwargs) + if timeout_ms is not None: + coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0) + + start_epoch = time.time() + start_perf = time.perf_counter() + thread_id = threading.get_ident() + try: + running_loop = asyncio.get_running_loop() + loop_id: Optional[int] = id(running_loop) + except RuntimeError: + loop_id = None + logger.debug( + "ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f " + "thread=%s loop=%s timeout_ms=%s", + self.__class__.__name__, + method_name, + self._instance_id, + start_epoch, + thread_id, + loop_id, + timeout_ms, + ) + + try: + # If we have a global loop (Main Thread Loop), use it for dispatch from worker threads + if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running(): + try: + curr_loop = asyncio.get_running_loop() + if curr_loop is _GLOBAL_LOOP: + pass + except RuntimeError: + # No running loop - we are in a worker thread. + future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP) + return future.result( + timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None + ) + + try: + asyncio.get_running_loop() + return run_coro_in_new_loop(coro) + except RuntimeError: + loop = get_thread_loop() + return loop.run_until_complete(coro) + except asyncio.TimeoutError as exc: + raise TimeoutError( + f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} " + f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})" + ) from exc + except concurrent.futures.TimeoutError as exc: + raise TimeoutError( + f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} " + f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})" + ) from exc + finally: + end_epoch = time.time() + elapsed_ms = (time.perf_counter() - start_perf) * 1000.0 + logger.debug( + "ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f " + "elapsed_ms=%.3f thread=%s loop=%s", + self.__class__.__name__, + method_name, + self._instance_id, + end_epoch, + elapsed_ms, + thread_id, + loop_id, + ) + + def __getstate__(self) -> Dict[str, Any]: + return {"_instance_id": self._instance_id} + + def __setstate__(self, state: Dict[str, Any]) -> None: + self._instance_id = state["_instance_id"] + self._rpc_caller = None + self._registry = self._registry_class() + self._manage_lifecycle = False + self._cleaned_up = False + + def cleanup(self) -> None: + if self._cleaned_up or IS_CHILD_PROCESS: + return + self._cleaned_up = True + finalizer = getattr(self, "_finalizer", None) + if finalizer is not None: + finalizer.detach() + self._registry.unregister_sync(self._instance_id) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self._instance_id}>" + + +def create_rpc_method(method_name: str) -> Callable[..., Any]: + def method(self: BaseProxy[Any], *args: Any, **kwargs: Any) -> Any: + return self._call_rpc(method_name, *args, **kwargs) + + method.__name__ = method_name + return method diff --git a/comfy/isolation/proxies/folder_paths_proxy.py b/comfy/isolation/proxies/folder_paths_proxy.py new file mode 100644 index 000000000..a2996ec24 --- /dev/null +++ b/comfy/isolation/proxies/folder_paths_proxy.py @@ -0,0 +1,29 @@ +from __future__ import annotations +from typing import Dict + +import folder_paths +from pyisolate import ProxiedSingleton + + +class FolderPathsProxy(ProxiedSingleton): + """ + Dynamic proxy for folder_paths. + Uses __getattr__ for most lookups, with explicit handling for + mutable collections to ensure efficient by-value transfer. + """ + + def __getattr__(self, name): + return getattr(folder_paths, name) + + # Return dict snapshots (avoid RPC chatter) + @property + def folder_names_and_paths(self) -> Dict: + return dict(folder_paths.folder_names_and_paths) + + @property + def extension_mimetypes_cache(self) -> Dict: + return dict(folder_paths.extension_mimetypes_cache) + + @property + def filename_list_cache(self) -> Dict: + return dict(folder_paths.filename_list_cache) diff --git a/comfy/isolation/proxies/helper_proxies.py b/comfy/isolation/proxies/helper_proxies.py new file mode 100644 index 000000000..a50b9e4c4 --- /dev/null +++ b/comfy/isolation/proxies/helper_proxies.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional + + +class AnyTypeProxy(str): + """Replacement for custom AnyType objects used by some nodes.""" + + def __new__(cls, value: str = "*"): + return super().__new__(cls, value) + + def __ne__(self, other): # type: ignore[override] + return False + + +class FlexibleOptionalInputProxy(dict): + """Replacement for FlexibleOptionalInputType to allow dynamic inputs.""" + + def __init__(self, flex_type, data: Optional[Dict[str, object]] = None): + super().__init__() + self.type = flex_type + if data: + self.update(data) + + def __getitem__(self, key): # type: ignore[override] + return (self.type,) + + def __contains__(self, key): # type: ignore[override] + return True + + +class ByPassTypeTupleProxy(tuple): + """Replacement for ByPassTypeTuple to mirror wildcard fallback behavior.""" + + def __new__(cls, values): + return super().__new__(cls, values) + + def __getitem__(self, index): # type: ignore[override] + if index >= len(self): + return AnyTypeProxy("*") + return super().__getitem__(index) + + +def _restore_special_value(value: Any) -> Any: + if isinstance(value, dict): + if value.get("__pyisolate_any_type__"): + return AnyTypeProxy(value.get("value", "*")) + if value.get("__pyisolate_flexible_optional__"): + flex_type = _restore_special_value(value.get("type")) + data_raw = value.get("data") + data = ( + {k: _restore_special_value(v) for k, v in data_raw.items()} + if isinstance(data_raw, dict) + else {} + ) + return FlexibleOptionalInputProxy(flex_type, data) + if value.get("__pyisolate_tuple__") is not None: + return tuple( + _restore_special_value(v) for v in value["__pyisolate_tuple__"] + ) + if value.get("__pyisolate_bypass_tuple__") is not None: + return ByPassTypeTupleProxy( + tuple( + _restore_special_value(v) + for v in value["__pyisolate_bypass_tuple__"] + ) + ) + return {k: _restore_special_value(v) for k, v in value.items()} + if isinstance(value, list): + return [_restore_special_value(v) for v in value] + return value + + +def restore_input_types(raw: Dict[str, object]) -> Dict[str, object]: + """Restore serialized INPUT_TYPES payload back into ComfyUI-compatible objects.""" + + if not isinstance(raw, dict): + return raw # type: ignore[return-value] + + restored: Dict[str, object] = {} + for section, entries in raw.items(): + if isinstance(entries, dict) and entries.get("__pyisolate_flexible_optional__"): + restored[section] = _restore_special_value(entries) + elif isinstance(entries, dict): + restored[section] = { + k: _restore_special_value(v) for k, v in entries.items() + } + else: + restored[section] = _restore_special_value(entries) + return restored + + +__all__ = [ + "AnyTypeProxy", + "FlexibleOptionalInputProxy", + "ByPassTypeTupleProxy", + "restore_input_types", +] diff --git a/comfy/isolation/proxies/model_management_proxy.py b/comfy/isolation/proxies/model_management_proxy.py new file mode 100644 index 000000000..00e14d9b4 --- /dev/null +++ b/comfy/isolation/proxies/model_management_proxy.py @@ -0,0 +1,27 @@ +import comfy.model_management as mm +from pyisolate import ProxiedSingleton + + +class ModelManagementProxy(ProxiedSingleton): + """ + Dynamic proxy for comfy.model_management. + Uses __getattr__ to forward all calls to the underlying module, + reducing maintenance burden. + """ + + # Explicitly expose Enums/Classes as properties + @property + def VRAMState(self): + return mm.VRAMState + + @property + def CPUState(self): + return mm.CPUState + + @property + def OOM_EXCEPTION(self): + return mm.OOM_EXCEPTION + + def __getattr__(self, name): + """Forward all other attribute access to the module.""" + return getattr(mm, name) diff --git a/comfy/isolation/proxies/progress_proxy.py b/comfy/isolation/proxies/progress_proxy.py new file mode 100644 index 000000000..44494ea31 --- /dev/null +++ b/comfy/isolation/proxies/progress_proxy.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import logging +from typing import Any, Optional + +try: + from pyisolate import ProxiedSingleton +except ImportError: + + class ProxiedSingleton: + pass + + +from comfy_execution.progress import get_progress_state + +logger = logging.getLogger(__name__) + + +class ProgressProxy(ProxiedSingleton): + def set_progress( + self, + value: float, + max_value: float, + node_id: Optional[str] = None, + image: Any = None, + ) -> None: + get_progress_state().update_progress( + node_id=node_id, + value=value, + max_value=max_value, + image=image, + ) + + +__all__ = ["ProgressProxy"] diff --git a/comfy/isolation/proxies/prompt_server_impl.py b/comfy/isolation/proxies/prompt_server_impl.py new file mode 100644 index 000000000..2a775e097 --- /dev/null +++ b/comfy/isolation/proxies/prompt_server_impl.py @@ -0,0 +1,265 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,redefined-outer-name,reimported,super-init-not-called +"""Stateless RPC Implementation for PromptServer. + +Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub architecture. +- Host: PromptServerService (RPC Handler) +- Child: PromptServerStub (Interface Implementation) +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any, Dict, Optional, Callable + +import logging +from aiohttp import web + +# IMPORTS +from pyisolate import ProxiedSingleton + +logger = logging.getLogger(__name__) +LOG_PREFIX = "[Isolation:C<->H]" + +# ... + +# ============================================================================= +# CHILD SIDE: PromptServerStub +# ============================================================================= + + +class PromptServerStub: + """Stateless Stub for PromptServer.""" + + # Masquerade as the real server module + __module__ = "server" + + _instance: Optional["PromptServerStub"] = None + _rpc: Optional[Any] = None # This will be the Caller object + _source_file: Optional[str] = None + + def __init__(self): + self.routes = RouteStub(self) + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + """Inject RPC client (called by adapter.py or manually).""" + # Create caller for HOST Service + # Assuming Host Service is registered as "PromptServerService" (class name) + # We target the Host Service Class + target_id = "PromptServerService" + # We need to pass a class to create_caller? Usually yes. + # But we don't have the Service class imported here necessarily (if running on child). + # pyisolate check verify_service type? + # If we pass PromptServerStub as the 'class', it might mismatch if checking types. + # But we can try passing PromptServerStub if it mirrors the service name? No, stub is PromptServerStub. + # We need a dummy class with right name? + # Or just rely on string ID if create_caller supports it? + # Standard: rpc.create_caller(PromptServerStub, target_id) + # But wait, PromptServerStub is the *Local* class. + # We want to call *Remote* class. + # If we use PromptServerStub as the type, returning object will be typed as PromptServerStub? + # The first arg is 'service_cls'. + cls._rpc = rpc.create_caller( + PromptServerService, target_id + ) # We import Service below? + + # We need PromptServerService available for the create_caller call? + # Or just use the Stub class if ID matches? + # prompt_server_impl.py defines BOTH. So PromptServerService IS available! + + @property + def instance(self) -> "PromptServerStub": + return self + + # ... Compatibility ... + @classmethod + def _get_source_file(cls) -> str: + if cls._source_file is None: + import folder_paths + + cls._source_file = os.path.join(folder_paths.base_path, "server.py") + return cls._source_file + + @property + def __file__(self) -> str: + return self._get_source_file() + + # --- Properties --- + @property + def client_id(self) -> Optional[str]: + return "isolated_client" + + def supports(self, feature: str) -> bool: + return True + + @property + def app(self): + raise RuntimeError( + "PromptServer.app is not accessible in isolated nodes. Use RPC routes instead." + ) + + @property + def prompt_queue(self): + raise RuntimeError( + "PromptServer.prompt_queue is not accessible in isolated nodes." + ) + + # --- UI Communication (RPC Delegates) --- + async def send_sync( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ) -> None: + if self._rpc: + await self._rpc.ui_send_sync(event, data, sid) + + async def send( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ) -> None: + if self._rpc: + await self._rpc.ui_send(event, data, sid) + + def send_progress_text(self, text: str, node_id: str, sid=None) -> None: + if self._rpc: + # Fire and forget likely needed. If method is async on host, caller invocation returns coroutine. + # We must schedule it? + # Or use fire_remote equivalent? + # Caller object usually proxies calls. If host method is async, it returns coro. + # If we are sync here (send_progress_text checks imply sync usage), we must background it. + # But UtilsProxy hook wrapper creates task. + # Does send_progress_text need to be sync? Yes, node code calls it sync. + import asyncio + + try: + loop = asyncio.get_running_loop() + loop.create_task(self._rpc.ui_send_progress_text(text, node_id, sid)) + except RuntimeError: + pass # Sync context without loop? + + # --- Route Registration Logic --- + def register_route(self, method: str, path: str, handler: Callable): + """Register a route handler via RPC.""" + if not self._rpc: + logger.error("RPC not initialized in PromptServerStub") + return + + # Fire registration async + try: + loop = asyncio.get_running_loop() + loop.create_task(self._rpc.register_route_rpc(method, path, handler)) + except RuntimeError: + pass + + +class RouteStub: + """Simulates aiohttp.web.RouteTableDef.""" + + def __init__(self, stub: PromptServerStub): + self._stub = stub + + def get(self, path: str): + def decorator(handler): + self._stub.register_route("GET", path, handler) + return handler + + return decorator + + def post(self, path: str): + def decorator(handler): + self._stub.register_route("POST", path, handler) + return handler + + return decorator + + def patch(self, path: str): + def decorator(handler): + self._stub.register_route("PATCH", path, handler) + return handler + + return decorator + + def put(self, path: str): + def decorator(handler): + self._stub.register_route("PUT", path, handler) + return handler + + return decorator + + def delete(self, path: str): + def decorator(handler): + self._stub.register_route("DELETE", path, handler) + return handler + + return decorator + + +# ============================================================================= +# HOST SIDE: PromptServerService +# ============================================================================= + + +class PromptServerService(ProxiedSingleton): + """Host-side RPC Service for PromptServer.""" + + def __init__(self): + # We will bind to the real server instance lazily or via global import + pass + + @property + def server(self): + from server import PromptServer + + return PromptServer.instance + + async def ui_send_sync( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ): + await self.server.send_sync(event, data, sid) + + async def ui_send( + self, event: str, data: Dict[str, Any], sid: Optional[str] = None + ): + await self.server.send(event, data, sid) + + async def ui_send_progress_text(self, text: str, node_id: str, sid=None): + # Made async to be awaitable by RPC layer + self.server.send_progress_text(text, node_id, sid) + + async def register_route_rpc(self, method: str, path: str, child_handler_proxy): + """RPC Target: Register a route that forwards to the Child.""" + logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}") + + async def route_wrapper(request: web.Request) -> web.Response: + # 1. Capture request data + req_data = { + "method": request.method, + "path": request.path, + "query": dict(request.query), + } + if request.can_read_body: + req_data["text"] = await request.text() + + try: + # 2. Call Child Handler via RPC (child_handler_proxy is async callable) + result = await child_handler_proxy(req_data) + + # 3. Serialize Response + return self._serialize_response(result) + except Exception as e: + logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}") + return web.Response(status=500, text=str(e)) + + # Register loop + self.server.app.router.add_route(method, path, route_wrapper) + + def _serialize_response(self, result: Any) -> web.Response: + """Helper to convert Child result -> web.Response""" + if isinstance(result, web.Response): + return result + # Handle dict (json) + if isinstance(result, dict): + return web.json_response(result) + # Handle string + if isinstance(result, str): + return web.Response(text=result) + # Fallback + return web.Response(text=str(result)) diff --git a/comfy/isolation/proxies/utils_proxy.py b/comfy/isolation/proxies/utils_proxy.py new file mode 100644 index 000000000..432f7ec90 --- /dev/null +++ b/comfy/isolation/proxies/utils_proxy.py @@ -0,0 +1,64 @@ +# pylint: disable=cyclic-import,import-outside-toplevel +from __future__ import annotations + +from typing import Optional, Any +import comfy.utils +from pyisolate import ProxiedSingleton + +import os + + +class UtilsProxy(ProxiedSingleton): + """ + Proxy for comfy.utils. + Primarily handles the PROGRESS_BAR_HOOK to ensure progress updates + from isolated nodes reach the host. + """ + + # _instance and __new__ removed to rely on SingletonMetaclass + _rpc: Optional[Any] = None + + @classmethod + def set_rpc(cls, rpc: Any) -> None: + # Create caller using class name as ID (standard for Singletons) + cls._rpc = rpc.create_caller(cls, "UtilsProxy") + + async def progress_bar_hook( + self, + value: int, + total: int, + preview: Optional[bytes] = None, + node_id: Optional[str] = None, + ) -> Any: + """ + Host-side implementation: forwards the call to the real global hook. + Child-side: this method call is intercepted by RPC and sent to host. + """ + if os.environ.get("PYISOLATE_CHILD") == "1": + # Manual RPC dispatch for Child process + # Use class-level RPC storage (Static Injection) + if UtilsProxy._rpc: + return await UtilsProxy._rpc.progress_bar_hook( + value, total, preview, node_id + ) + + # Fallback channel: global child rpc + try: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + get_child_rpc_instance() + # If we have an RPC instance but no UtilsProxy._rpc, we *could* try to use it, + # but we need a caller. For now, just pass to avoid crashing. + pass + except (ImportError, LookupError): + pass + + return None + + # Host Execution + if comfy.utils.PROGRESS_BAR_HOOK is not None: + comfy.utils.PROGRESS_BAR_HOOK(value, total, preview, node_id) + + def set_progress_bar_global_hook(self, hook: Any) -> None: + """Forward hook registration (though usually not needed from child).""" + comfy.utils.set_progress_bar_global_hook(hook) diff --git a/comfy/isolation/rpc_bridge.py b/comfy/isolation/rpc_bridge.py new file mode 100644 index 000000000..2beb0f09f --- /dev/null +++ b/comfy/isolation/rpc_bridge.py @@ -0,0 +1,49 @@ +import asyncio +import logging +import threading + +logger = logging.getLogger(__name__) + + +class RpcBridge: + """Minimal helper to run coroutines synchronously inside isolated processes. + + If an event loop is already running, the coroutine is executed on a fresh + thread with its own loop to avoid nested run_until_complete errors. + """ + + def run_sync(self, maybe_coro): + if not asyncio.iscoroutine(maybe_coro): + return maybe_coro + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + result_container = {} + exc_container = {} + + def _runner(): + try: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + result_container["value"] = new_loop.run_until_complete(maybe_coro) + except Exception as exc: # pragma: no cover + exc_container["error"] = exc + finally: + try: + new_loop.close() + except Exception: + pass + + t = threading.Thread(target=_runner, daemon=True) + t.start() + t.join() + + if "error" in exc_container: + raise exc_container["error"] + return result_container.get("value") + + return asyncio.run(maybe_coro) diff --git a/comfy/isolation/runtime_helpers.py b/comfy/isolation/runtime_helpers.py new file mode 100644 index 000000000..767e222f2 --- /dev/null +++ b/comfy/isolation/runtime_helpers.py @@ -0,0 +1,343 @@ +# pylint: disable=consider-using-from-import,import-outside-toplevel,no-member +from __future__ import annotations + +import copy +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Set, TYPE_CHECKING + +from .proxies.helper_proxies import restore_input_types +from comfy_api.internal import _ComfyNodeInternal +from comfy_api.latest import _io as latest_io +from .shm_forensics import scan_shm_forensics + +if TYPE_CHECKING: + from .extension_wrapper import ComfyNodeExtension + +LOG_PREFIX = "][" +_PRE_EXEC_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 + + +def _resource_snapshot() -> Dict[str, int]: + fd_count = -1 + shm_sender_files = 0 + try: + fd_count = len(os.listdir("/proc/self/fd")) + except Exception: + pass + try: + shm_root = Path("/dev/shm") + if shm_root.exists(): + prefix = f"torch_{os.getpid()}_" + shm_sender_files = sum(1 for _ in shm_root.glob(f"{prefix}*")) + except Exception: + pass + return {"fd_count": fd_count, "shm_sender_files": shm_sender_files} + + +def _tensor_transport_summary(value: Any) -> Dict[str, int]: + summary: Dict[str, int] = { + "tensor_count": 0, + "cpu_tensors": 0, + "cuda_tensors": 0, + "shared_cpu_tensors": 0, + "tensor_bytes": 0, + } + try: + import torch + except Exception: + return summary + + def visit(node: Any) -> None: + if isinstance(node, torch.Tensor): + summary["tensor_count"] += 1 + summary["tensor_bytes"] += int(node.numel() * node.element_size()) + if node.device.type == "cpu": + summary["cpu_tensors"] += 1 + if node.is_shared(): + summary["shared_cpu_tensors"] += 1 + elif node.device.type == "cuda": + summary["cuda_tensors"] += 1 + return + if isinstance(node, dict): + for v in node.values(): + visit(v) + return + if isinstance(node, (list, tuple)): + for v in node: + visit(v) + + visit(value) + return summary + + +def _extract_hidden_unique_id(inputs: Dict[str, Any]) -> str | None: + for key, value in inputs.items(): + key_text = str(key) + if "unique_id" in key_text: + return str(value) + return None + + +def _flush_tensor_transport_state(marker: str, logger: logging.Logger) -> None: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + return + if not callable(flush_tensor_keeper): + return + flushed = flush_tensor_keeper() + if flushed > 0: + logger.debug( + "%s %s flush_tensor_keeper released=%d", LOG_PREFIX, marker, flushed + ) + + +def _relieve_host_vram_pressure(marker: str, logger: logging.Logger) -> None: + import comfy.model_management as model_management + + model_management.cleanup_models_gc() + model_management.cleanup_models() + + device = model_management.get_torch_device() + if not hasattr(device, "type") or device.type == "cpu": + return + + required = max( + model_management.minimum_inference_memory(), + _PRE_EXEC_MIN_FREE_VRAM_BYTES, + ) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=True) + if model_management.get_free_memory(device) < required: + model_management.free_memory(required, device, for_dynamic=False) + model_management.cleanup_models() + model_management.soft_empty_cache() + logger.debug("%s %s free_memory target=%d", LOG_PREFIX, marker, required) + + +def _detach_shared_cpu_tensors(value: Any) -> Any: + try: + import torch + except Exception: + return value + + if isinstance(value, torch.Tensor): + if value.device.type == "cpu" and value.is_shared(): + clone = value.clone() + if value.requires_grad: + clone.requires_grad_(True) + return clone + return value + if isinstance(value, list): + return [_detach_shared_cpu_tensors(v) for v in value] + if isinstance(value, tuple): + return tuple(_detach_shared_cpu_tensors(v) for v in value) + if isinstance(value, dict): + return {k: _detach_shared_cpu_tensors(v) for k, v in value.items()} + return value + + +def build_stub_class( + node_name: str, + info: Dict[str, object], + extension: "ComfyNodeExtension", + running_extensions: Dict[str, "ComfyNodeExtension"], + logger: logging.Logger, +) -> type: + is_v3 = bool(info.get("is_v3", False)) + function_name = "_pyisolate_execute" + restored_input_types = restore_input_types(info.get("input_types", {})) + + async def _execute(self, **inputs): + from comfy.isolation import _RUNNING_EXTENSIONS + + # Update BOTH the local dict AND the module-level dict + running_extensions[extension.name] = extension + _RUNNING_EXTENSIONS[extension.name] = extension + prev_child = None + node_unique_id = _extract_hidden_unique_id(inputs) + summary = _tensor_transport_summary(inputs) + resources = _resource_snapshot() + logger.debug( + "%s ISO:execute_start ext=%s node=%s uid=%s tensors=%d cpu=%d cuda=%d shared_cpu=%d bytes=%d fds=%d sender_shm=%d", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + summary["tensor_count"], + summary["cpu_tensors"], + summary["cuda_tensors"], + summary["shared_cpu_tensors"], + summary["tensor_bytes"], + resources["fd_count"], + resources["shm_sender_files"], + ) + scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True) + try: + if os.environ.get("PYISOLATE_CHILD") != "1": + _relieve_host_vram_pressure("RUNTIME:pre_execute", logger) + scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True) + from pyisolate._internal.model_serialization import ( + serialize_for_isolation, + deserialize_from_isolation, + ) + + prev_child = os.environ.pop("PYISOLATE_CHILD", None) + logger.debug( + "%s ISO:serialize_start ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + serialized = serialize_for_isolation(inputs) + logger.debug( + "%s ISO:serialize_done ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + logger.debug( + "%s ISO:dispatch_start ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + result = await extension.execute_node(node_name, **serialized) + logger.debug( + "%s ISO:dispatch_done ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + deserialized = await deserialize_from_isolation(result, extension) + scan_shm_forensics("RUNTIME:post_execute", refresh_model_context=True) + return _detach_shared_cpu_tensors(deserialized) + except ImportError: + return await extension.execute_node(node_name, **inputs) + except Exception: + logger.exception( + "%s ISO:execute_error ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + raise + finally: + if prev_child is not None: + os.environ["PYISOLATE_CHILD"] = prev_child + logger.debug( + "%s ISO:execute_end ext=%s node=%s uid=%s", + LOG_PREFIX, + extension.name, + node_name, + node_unique_id or "-", + ) + scan_shm_forensics("RUNTIME:execute_end", refresh_model_context=True) + + def _input_types( + cls, + include_hidden: bool = True, + return_schema: bool = False, + live_inputs: Any = None, + ): + if not is_v3: + return restored_input_types + + inputs_copy = copy.deepcopy(restored_input_types) + if not include_hidden: + inputs_copy.pop("hidden", None) + + v3_data: Dict[str, Any] = {"hidden_inputs": {}} + dynamic = inputs_copy.pop("dynamic_paths", None) + if dynamic is not None: + v3_data["dynamic_paths"] = dynamic + + if return_schema: + hidden_vals = info.get("hidden", []) or [] + hidden_enums = [] + for h in hidden_vals: + try: + hidden_enums.append(latest_io.Hidden(h)) + except Exception: + hidden_enums.append(h) + + class SchemaProxy: + hidden = hidden_enums + + return inputs_copy, SchemaProxy, v3_data + return inputs_copy + + def _validate_class(cls): + return True + + def _get_node_info_v1(cls): + return info.get("schema_v1", {}) + + def _get_base_class(cls): + return latest_io.ComfyNode + + attributes: Dict[str, object] = { + "FUNCTION": function_name, + "CATEGORY": info.get("category", ""), + "OUTPUT_NODE": info.get("output_node", False), + "RETURN_TYPES": tuple(info.get("return_types", ()) or ()), + "RETURN_NAMES": info.get("return_names"), + function_name: _execute, + "_pyisolate_extension": extension, + "_pyisolate_node_name": node_name, + "INPUT_TYPES": classmethod(_input_types), + } + + output_is_list = info.get("output_is_list") + if output_is_list is not None: + attributes["OUTPUT_IS_LIST"] = tuple(output_is_list) + + if is_v3: + attributes["VALIDATE_CLASS"] = classmethod(_validate_class) + attributes["GET_NODE_INFO_V1"] = classmethod(_get_node_info_v1) + attributes["GET_BASE_CLASS"] = classmethod(_get_base_class) + attributes["DESCRIPTION"] = info.get("description", "") + attributes["EXPERIMENTAL"] = info.get("experimental", False) + attributes["DEPRECATED"] = info.get("deprecated", False) + attributes["API_NODE"] = info.get("api_node", False) + attributes["NOT_IDEMPOTENT"] = info.get("not_idempotent", False) + attributes["INPUT_IS_LIST"] = info.get("input_is_list", False) + + class_name = f"PyIsolate_{node_name}".replace(" ", "_") + bases = (_ComfyNodeInternal,) if is_v3 else () + stub_cls = type(class_name, bases, attributes) + + if is_v3: + try: + stub_cls.VALIDATE_CLASS() + except Exception as e: + logger.error("%s VALIDATE_CLASS failed: %s - %s", LOG_PREFIX, node_name, e) + + return stub_cls + + +def get_class_types_for_extension( + extension_name: str, + running_extensions: Dict[str, "ComfyNodeExtension"], + specs: List[Any], +) -> Set[str]: + extension = running_extensions.get(extension_name) + if not extension: + return set() + + ext_path = Path(extension.module_path) + class_types = set() + for spec in specs: + if spec.module_path.resolve() == ext_path.resolve(): + class_types.add(spec.node_name) + return class_types + + +__all__ = ["build_stub_class", "get_class_types_for_extension"] diff --git a/comfy/isolation/shm_forensics.py b/comfy/isolation/shm_forensics.py new file mode 100644 index 000000000..36223505a --- /dev/null +++ b/comfy/isolation/shm_forensics.py @@ -0,0 +1,217 @@ +# pylint: disable=consider-using-from-import,import-outside-toplevel +from __future__ import annotations + +import atexit +import hashlib +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Set + +LOG_PREFIX = "][" +logger = logging.getLogger(__name__) + + +def _shm_debug_enabled() -> bool: + return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1" + + +class _SHMForensicsTracker: + def __init__(self) -> None: + self._started = False + self._tracked_files: Set[str] = set() + self._current_model_context: Dict[str, str] = { + "id": "unknown", + "name": "unknown", + "hash": "????", + } + + @staticmethod + def _snapshot_shm() -> Set[str]: + shm_path = Path("/dev/shm") + if not shm_path.exists(): + return set() + return {f.name for f in shm_path.glob("torch_*")} + + def start(self) -> None: + if self._started or not _shm_debug_enabled(): + return + self._tracked_files = self._snapshot_shm() + self._started = True + logger.debug( + "%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files) + ) + + def stop(self) -> None: + if not self._started: + return + self.scan("shutdown", refresh_model_context=True) + self._started = False + logger.debug("%s SHM:forensics_disabled", LOG_PREFIX) + + def _compute_model_hash(self, model_patcher: Any) -> str: + try: + model_instance_id = getattr(model_patcher, "_instance_id", None) + if model_instance_id is not None: + model_id_text = str(model_instance_id) + return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text + + import torch + + real_model = ( + model_patcher.model + if hasattr(model_patcher, "model") + else model_patcher + ) + tensor = None + if hasattr(real_model, "parameters"): + for p in real_model.parameters(): + if torch.is_tensor(p) and p.numel() > 0: + tensor = p + break + + if tensor is None: + return "0000" + + flat = tensor.flatten() + values = [] + indices = [0, flat.shape[0] // 2, flat.shape[0] - 1] + for i in indices: + if i < flat.shape[0]: + values.append(flat[i].item()) + + size = 0 + if hasattr(model_patcher, "model_size"): + size = model_patcher.model_size() + sample_str = f"{values}_{id(model_patcher):016x}_{size}" + return hashlib.sha256(sample_str.encode()).hexdigest()[-4:] + except Exception: + return "err!" + + def _get_models_snapshot(self) -> List[Dict[str, Any]]: + try: + import comfy.model_management as model_management + except Exception: + return [] + + snapshot: List[Dict[str, Any]] = [] + try: + for loaded_model in model_management.current_loaded_models: + model = loaded_model.model + if model is None: + continue + if str(getattr(loaded_model, "device", "")) != "cuda:0": + continue + + name = ( + model.model.__class__.__name__ + if hasattr(model, "model") + else type(model).__name__ + ) + model_hash = self._compute_model_hash(model) + model_instance_id = getattr(model, "_instance_id", None) + if model_instance_id is None: + model_instance_id = model_hash + snapshot.append( + { + "name": str(name), + "id": str(model_instance_id), + "hash": str(model_hash or "????"), + "used": bool(getattr(loaded_model, "currently_used", False)), + } + ) + except Exception: + return [] + + return snapshot + + def _update_model_context(self) -> None: + snapshot = self._get_models_snapshot() + selected = None + + used_models = [m for m in snapshot if m.get("used") and m.get("id")] + if used_models: + selected = used_models[-1] + else: + live_models = [m for m in snapshot if m.get("id")] + if live_models: + selected = live_models[-1] + + if selected is None: + self._current_model_context = { + "id": "unknown", + "name": "unknown", + "hash": "????", + } + return + + self._current_model_context = { + "id": str(selected.get("id", "unknown")), + "name": str(selected.get("name", "unknown")), + "hash": str(selected.get("hash", "????") or "????"), + } + + def scan(self, marker: str, refresh_model_context: bool = True) -> None: + if not self._started or not _shm_debug_enabled(): + return + + if refresh_model_context: + self._update_model_context() + + current = self._snapshot_shm() + added = current - self._tracked_files + removed = self._tracked_files - current + self._tracked_files = current + + if not added and not removed: + logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker) + return + + for filename in sorted(added): + logger.info("%s SHM:created | %s", LOG_PREFIX, filename) + model_id = self._current_model_context["id"] + if model_id == "unknown": + logger.error( + "%s SHM:model_association_missing | file=%s | reason=no_active_model_context", + LOG_PREFIX, + filename, + ) + else: + logger.info( + "%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s", + LOG_PREFIX, + model_id, + filename, + self._current_model_context["name"], + self._current_model_context["hash"], + ) + + for filename in sorted(removed): + logger.info("%s SHM:deleted | %s", LOG_PREFIX, filename) + + logger.debug( + "%s SHM:scan marker=%s created=%d deleted=%d active=%d", + LOG_PREFIX, + marker, + len(added), + len(removed), + len(self._tracked_files), + ) + + +_TRACKER = _SHMForensicsTracker() + + +def start_shm_forensics() -> None: + _TRACKER.start() + + +def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None: + _TRACKER.scan(marker, refresh_model_context=refresh_model_context) + + +def stop_shm_forensics() -> None: + _TRACKER.stop() + + +atexit.register(stop_shm_forensics) diff --git a/comfy/isolation/vae_proxy.py b/comfy/isolation/vae_proxy.py new file mode 100644 index 000000000..8260d06a3 --- /dev/null +++ b/comfy/isolation/vae_proxy.py @@ -0,0 +1,214 @@ +# pylint: disable=attribute-defined-outside-init +import logging +from typing import Any + +from comfy.isolation.proxies.base import ( + IS_CHILD_PROCESS, + BaseProxy, + BaseRegistry, + detach_if_grad, +) +from comfy.isolation.model_patcher_proxy import ModelPatcherProxy, ModelPatcherRegistry + +logger = logging.getLogger(__name__) + + +class FirstStageModelRegistry(BaseRegistry[Any]): + _type_prefix = "first_stage_model" + + async def get_property(self, instance_id: str, name: str) -> Any: + obj = self._get_instance(instance_id) + return getattr(obj, name) + + async def has_property(self, instance_id: str, name: str) -> bool: + obj = self._get_instance(instance_id) + return hasattr(obj, name) + + +class FirstStageModelProxy(BaseProxy[FirstStageModelRegistry]): + _registry_class = FirstStageModelRegistry + __module__ = "comfy.ldm.models.autoencoder" + + def __getattr__(self, name: str) -> Any: + try: + return self._call_rpc("get_property", name) + except Exception as e: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) from e + + def __repr__(self) -> str: + return f"" + + +class VAERegistry(BaseRegistry[Any]): + _type_prefix = "vae" + + async def get_patcher_id(self, instance_id: str) -> str: + vae = self._get_instance(instance_id) + return ModelPatcherRegistry().register(vae.patcher) + + async def get_first_stage_model_id(self, instance_id: str) -> str: + vae = self._get_instance(instance_id) + return FirstStageModelRegistry().register(vae.first_stage_model) + + async def encode(self, instance_id: str, pixels: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).encode(pixels)) + + async def encode_tiled( + self, + instance_id: str, + pixels: Any, + tile_x: int = 512, + tile_y: int = 512, + overlap: int = 64, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).encode_tiled( + pixels, tile_x=tile_x, tile_y=tile_y, overlap=overlap + ) + ) + + async def decode(self, instance_id: str, samples: Any, **kwargs: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).decode(samples, **kwargs)) + + async def decode_tiled( + self, + instance_id: str, + samples: Any, + tile_x: int = 64, + tile_y: int = 64, + overlap: int = 16, + **kwargs: Any, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).decode_tiled( + samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap, **kwargs + ) + ) + + async def get_property(self, instance_id: str, name: str) -> Any: + return getattr(self._get_instance(instance_id), name) + + async def memory_used_encode(self, instance_id: str, shape: Any, dtype: Any) -> int: + return self._get_instance(instance_id).memory_used_encode(shape, dtype) + + async def memory_used_decode(self, instance_id: str, shape: Any, dtype: Any) -> int: + return self._get_instance(instance_id).memory_used_decode(shape, dtype) + + async def process_input(self, instance_id: str, image: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).process_input(image)) + + async def process_output(self, instance_id: str, image: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).process_output(image)) + + +class VAEProxy(BaseProxy[VAERegistry]): + _registry_class = VAERegistry + __module__ = "comfy.sd" + + @property + def patcher(self) -> ModelPatcherProxy: + if not hasattr(self, "_patcher_proxy"): + patcher_id = self._call_rpc("get_patcher_id") + self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False) + return self._patcher_proxy + + @property + def first_stage_model(self) -> FirstStageModelProxy: + if not hasattr(self, "_first_stage_model_proxy"): + fsm_id = self._call_rpc("get_first_stage_model_id") + self._first_stage_model_proxy = FirstStageModelProxy( + fsm_id, manage_lifecycle=False + ) + return self._first_stage_model_proxy + + @property + def vae_dtype(self) -> Any: + return self._get_property("vae_dtype") + + def encode(self, pixels: Any) -> Any: + return self._call_rpc("encode", pixels) + + def encode_tiled( + self, pixels: Any, tile_x: int = 512, tile_y: int = 512, overlap: int = 64 + ) -> Any: + return self._call_rpc("encode_tiled", pixels, tile_x, tile_y, overlap) + + def decode(self, samples: Any, **kwargs: Any) -> Any: + return self._call_rpc("decode", samples, **kwargs) + + def decode_tiled( + self, + samples: Any, + tile_x: int = 64, + tile_y: int = 64, + overlap: int = 16, + **kwargs: Any, + ) -> Any: + return self._call_rpc( + "decode_tiled", samples, tile_x, tile_y, overlap, **kwargs + ) + + def get_sd(self) -> Any: + return self._call_rpc("get_sd") + + def _get_property(self, name: str) -> Any: + return self._call_rpc("get_property", name) + + @property + def latent_dim(self) -> int: + return self._get_property("latent_dim") + + @property + def latent_channels(self) -> int: + return self._get_property("latent_channels") + + @property + def downscale_ratio(self) -> Any: + return self._get_property("downscale_ratio") + + @property + def upscale_ratio(self) -> Any: + return self._get_property("upscale_ratio") + + @property + def output_channels(self) -> int: + return self._get_property("output_channels") + + @property + def check_not_vide(self) -> bool: + return self._get_property("not_video") + + @property + def device(self) -> Any: + return self._get_property("device") + + @property + def working_dtypes(self) -> Any: + return self._get_property("working_dtypes") + + @property + def disable_offload(self) -> bool: + return self._get_property("disable_offload") + + @property + def size(self) -> Any: + return self._get_property("size") + + def memory_used_encode(self, shape: Any, dtype: Any) -> int: + return self._call_rpc("memory_used_encode", shape, dtype) + + def memory_used_decode(self, shape: Any, dtype: Any) -> int: + return self._call_rpc("memory_used_decode", shape, dtype) + + def process_input(self, image: Any) -> Any: + return self._call_rpc("process_input", image) + + def process_output(self, image: Any) -> Any: + return self._call_rpc("process_output", image) + + +if not IS_CHILD_PROCESS: + _VAE_REGISTRY_SINGLETON = VAERegistry() + _FIRST_STAGE_MODEL_REGISTRY_SINGLETON = FirstStageModelRegistry() diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 6978eb717..4ed4a9250 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,4 +1,5 @@ import math +import os from functools import partial from scipy import integrate @@ -12,8 +13,8 @@ from . import deis from . import sa_solver import comfy.model_patcher import comfy.model_sampling - import comfy.memory_management +from comfy.cli_args import args from comfy.utils import model_trange as trange def append_zero(x): @@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + if isolation_active: + target_device = sigmas.device + if x.device != target_device: + x = x.to(target_device) + s_in = s_in.to(target_device) + for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. diff --git a/comfy/model_base.py b/comfy/model_base.py index d9d5a9293..4f18adada 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1 import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging +import os import comfy.ldm.lightricks.av_model from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC @@ -112,8 +113,20 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.IMG_TO_IMG_FLOW: c = comfy.model_sampling.IMG_TO_IMG_FLOW + from comfy.cli_args import args + isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + class ModelSampling(s, c): - pass + if isolation_runtime_enabled: + def __reduce__(self): + """Ensure pickling yields a proxy instead of failing on local class.""" + try: + from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy + registry = ModelSamplingRegistry() + ms_id = registry.register(self) + return (ModelSamplingProxy, (ms_id,)) + except Exception as exc: + raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc return ModelSampling(model_config) diff --git a/comfy/model_management.py b/comfy/model_management.py index 81c89b180..8ed3f8a88 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -372,7 +372,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN' try: if is_amd(): - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0] + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD @@ -400,7 +400,7 @@ try: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True if rocm_version >= (7, 0): if any((a in arch) for a in ["gfx1200", "gfx1201"]): @@ -497,6 +497,9 @@ except: current_loaded_models = [] +def _isolation_mode_enabled(): + return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" + def module_size(module): module_mem = 0 sd = module.state_dict() @@ -576,8 +579,9 @@ class LoadedModel: if freed >= memory_to_free: return False self.model.detach(unpatch_weights) - self.model_finalizer.detach() - self.model_finalizer = None + if self.model_finalizer is not None: + self.model_finalizer.detach() + self.model_finalizer = None self.real_model = None return True @@ -591,8 +595,15 @@ class LoadedModel: if self._patcher_finalizer is not None: self._patcher_finalizer.detach() + def dead_state(self): + model_ref_gone = self.model is None + real_model_ref = self.real_model + real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None + return model_ref_gone, real_model_ref_gone + def is_dead(self): - return self.real_model() is not None and self.model is None + model_ref_gone, real_model_ref_gone = self.dead_state() + return model_ref_gone or real_model_ref_gone def use_more_memory(extra_memory, loaded_models, device): @@ -638,6 +649,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ unloaded_model = [] can_unload = [] unloaded_models = [] + isolation_active = _isolation_mode_enabled() for i in range(len(current_loaded_models) -1, -1, -1): shift_model = current_loaded_models[i] @@ -646,6 +658,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) shift_model.currently_used = False + if can_unload and isolation_active: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + flush_tensor_keeper = None + if callable(flush_tensor_keeper): + flushed = flush_tensor_keeper() + if flushed > 0: + logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed) + gc.collect() + for x in sorted(can_unload): i = x[-1] memory_to_free = 1e32 @@ -666,7 +689,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ current_loaded_models[i].model.partially_unload_ram(ram_to_free) for i in sorted(unloaded_model, reverse=True): - unloaded_models.append(current_loaded_models.pop(i)) + unloaded = current_loaded_models.pop(i) + model_obj = unloaded.model + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() + unloaded_models.append(unloaded) if len(unloaded_model) > 0: soft_empty_cache() @@ -725,7 +754,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu for i in to_unload: model_to_unload = current_loaded_models.pop(i) model_to_unload.model.detach(unpatch_all=False) - model_to_unload.model_finalizer.detach() + if model_to_unload.model_finalizer is not None: + model_to_unload.model_finalizer.detach() + model_to_unload.model_finalizer = None total_memory_required = {} @@ -788,25 +819,62 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): - do_gc = False - reset_cast_buffers() + if not _isolation_mode_enabled(): + dead_found = False + for i in range(len(current_loaded_models)): + if current_loaded_models[i].is_dead(): + dead_found = True + break + if dead_found: + logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.") + gc.collect() + soft_empty_cache() + + for i in range(len(current_loaded_models) - 1, -1, -1): + cur = current_loaded_models[i] + if cur.is_dead(): + logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.") + leaked = current_loaded_models.pop(i) + model_obj = getattr(leaked, "model", None) + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() + return + + dead_found = False + has_real_model_leak = False for i in range(len(current_loaded_models)): - cur = current_loaded_models[i] - if cur.is_dead(): - logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__)) - do_gc = True - break + model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state() + if model_ref_gone or real_model_ref_gone: + dead_found = True + if real_model_ref_gone and not model_ref_gone: + has_real_model_leak = True - if do_gc: + if dead_found: + if has_real_model_leak: + logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.") + else: + logging.debug("Cleaning stale loaded-model entries with released patcher references.") gc.collect() soft_empty_cache() - for i in range(len(current_loaded_models)): + for i in range(len(current_loaded_models) - 1, -1, -1): cur = current_loaded_models[i] - if cur.is_dead(): - logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) + model_ref_gone, real_model_ref_gone = cur.dead_state() + if model_ref_gone or real_model_ref_gone: + if real_model_ref_gone and not model_ref_gone: + logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.") + else: + logging.debug("Cleaning stale loaded-model entry with released patcher reference.") + leaked = current_loaded_models.pop(i) + model_obj = getattr(leaked, "model", None) + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() def archive_model_dtypes(model): @@ -820,11 +888,20 @@ def archive_model_dtypes(model): def cleanup_models(): to_delete = [] for i in range(len(current_loaded_models)): - if current_loaded_models[i].real_model() is None: + real_model_ref = current_loaded_models[i].real_model + if real_model_ref is None: + to_delete = [i] + to_delete + continue + if callable(real_model_ref) and real_model_ref() is None: to_delete = [i] + to_delete for i in to_delete: x = current_loaded_models.pop(i) + model_obj = getattr(x, "model", None) + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() del x def dtype_size(dtype): diff --git a/comfy/samplers.py b/comfy/samplers.py index 8be449ef7..b79ac575b 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -11,12 +11,14 @@ from functools import partial import collections import math import logging +import os import comfy.sampler_helpers import comfy.model_patcher import comfy.patcher_extension import comfy.hooks import comfy.context_windows import comfy.utils +from comfy.cli_args import args import scipy.stats import numpy @@ -210,9 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc _calc_cond_batch, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True) ) - return executor.execute(model, conds, x_in, timestep, model_options) + result = executor.execute(model, conds, x_in, timestep, model_options) + return result def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" out_conds = [] out_counts = [] # separate conds by matching hooks @@ -269,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for k, v in to_run[tt][0].conditioning.items(): cond_shapes[k].append(v.size()) - if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: + memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes) + if memory_required * 1.5 < free_memory: to_batch = batch_amount break @@ -294,9 +299,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens patches = p.patches batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) + if isolation_active: + target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device + input_x = torch.cat(input_x).to(target_device) + else: + input_x = torch.cat(input_x) c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) + if isolation_active: + timestep_ = torch.cat([timestep] * batch_chunks).to(target_device) + mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult] + else: + timestep_ = torch.cat([timestep] * batch_chunks) transformer_options = model.current_patcher.apply_hooks(hooks=hooks) if 'transformer_options' in model_options: @@ -327,9 +340,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for o in range(batch_chunks): cond_index = cond_or_uncond[o] a = area[o] + out_t = output[o] + mult_t = mult[o] + if isolation_active: + target_dev = out_conds[cond_index].device + if hasattr(out_t, "device") and out_t.device != target_dev: + out_t = out_t.to(target_dev) + if hasattr(mult_t, "device") and mult_t.device != target_dev: + mult_t = mult_t.to(target_dev) if a is None: - out_conds[cond_index] += output[o] * mult[o] - out_counts[cond_index] += mult[o] + out_conds[cond_index] += out_t * mult_t + out_counts[cond_index] += mult_t else: out_c = out_conds[cond_index] out_cts = out_counts[cond_index] @@ -337,8 +358,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for i in range(dims): out_c = out_c.narrow(i + 2, a[i + dims], a[i]) out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) - out_c += output[o] * mult[o] - out_cts += mult[o] + out_c += out_t * mult_t + out_cts += mult_t for i in range(len(out_conds)): out_conds[i] /= out_counts[i] @@ -392,14 +413,31 @@ class KSamplerX0Inpaint: self.inner_model = model self.sigmas = sigmas def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None): + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1" if denoise_mask is not None: + if isolation_active and denoise_mask.device != x.device: + denoise_mask = denoise_mask.to(x.device) if "denoise_mask_function" in model_options: denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask - x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask + if isolation_active: + latent_image = self.latent_image + if hasattr(latent_image, "device") and latent_image.device != x.device: + latent_image = latent_image.to(x.device) + scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image) + if hasattr(scaled, "device") and scaled.device != x.device: + scaled = scaled.to(x.device) + else: + scaled = self.inner_model.inner_model.scale_latent_inpaint( + x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image + ) + x = x * denoise_mask + scaled * latent_mask out = self.inner_model(x, sigma, model_options=model_options, seed=seed) if denoise_mask is not None: - out = out * denoise_mask + self.latent_image * latent_mask + latent_image = self.latent_image + if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device: + latent_image = latent_image.to(out.device) + out = out * denoise_mask + latent_image * latent_mask return out def simple_scheduler(model_sampling, steps): @@ -741,7 +779,11 @@ class KSAMPLER(Sampler): else: model_k.noise = noise - noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)) + max_denoise = self.max_denoise(model_wrap, sigmas) + model_sampling = model_wrap.inner_model.model_sampling + noise = model_sampling.noise_scaling( + sigmas[0], noise, latent_image, max_denoise + ) k_callback = None total_steps = len(sigmas) - 1 diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 050031dc0..3bb85b5c3 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL, SVG as _SVG, File3D +from ._util import MESH, VOXEL, SVG as _SVG, File3D, PLY as _PLY, NPZ as _NPZ class FolderType(str, Enum): @@ -678,6 +678,16 @@ class Mesh(ComfyTypeIO): Type = MESH +@comfytype(io_type="PLY") +class Ply(ComfyTypeIO): + Type = _PLY + + +@comfytype(io_type="NPZ") +class Npz(ComfyTypeIO): + Type = _NPZ + + @comfytype(io_type="FILE_3D") class File3DAny(ComfyTypeIO): """General 3D file type - accepts any supported 3D format.""" @@ -2197,6 +2207,8 @@ __all__ = [ "LossMap", "Voxel", "Mesh", + "Ply", + "Npz", "File3DAny", "File3DGLB", "File3DGLTF", diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index 115baf392..7d9ca337b 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,6 +1,8 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents from .geometry_types import VOXEL, MESH, File3D from .image_types import SVG +from .ply_types import PLY +from .npz_types import NPZ __all__ = [ # Utility Types @@ -11,4 +13,6 @@ __all__ = [ "MESH", "File3D", "SVG", + "PLY", + "NPZ", ] diff --git a/comfy_api/latest/_util/npz_types.py b/comfy_api/latest/_util/npz_types.py new file mode 100644 index 000000000..a93eed68c --- /dev/null +++ b/comfy_api/latest/_util/npz_types.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import os + + +class NPZ: + """Ordered collection of NPZ file payloads. + + Each entry in ``frames`` is a complete compressed ``.npz`` file stored + as raw bytes (produced by ``numpy.savez_compressed`` into a BytesIO). + ``save_to`` writes numbered files into a directory. + """ + + def __init__(self, frames: list[bytes]) -> None: + self.frames = frames + + @property + def num_frames(self) -> int: + return len(self.frames) + + def save_to(self, directory: str, prefix: str = "frame") -> str: + os.makedirs(directory, exist_ok=True) + for i, frame_bytes in enumerate(self.frames): + path = os.path.join(directory, f"{prefix}_{i:06d}.npz") + with open(path, "wb") as f: + f.write(frame_bytes) + return directory diff --git a/comfy_api/latest/_util/ply_types.py b/comfy_api/latest/_util/ply_types.py new file mode 100644 index 000000000..8beb566bc --- /dev/null +++ b/comfy_api/latest/_util/ply_types.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import numpy as np + + +class PLY: + """Point cloud payload for PLY file output. + + Supports two schemas: + - Pointcloud: xyz positions with optional colors, confidence, view_id (ASCII format) + - Gaussian: raw binary PLY data built by producer nodes using plyfile (binary format) + + When ``raw_data`` is provided, the object acts as an opaque binary PLY + carrier and ``save_to`` writes the bytes directly. + """ + + def __init__( + self, + points: np.ndarray | None = None, + colors: np.ndarray | None = None, + confidence: np.ndarray | None = None, + view_id: np.ndarray | None = None, + raw_data: bytes | None = None, + ) -> None: + self.raw_data = raw_data + if raw_data is not None: + self.points = None + self.colors = None + self.confidence = None + self.view_id = None + return + if points is None: + raise ValueError("Either points or raw_data must be provided") + if points.ndim != 2 or points.shape[1] != 3: + raise ValueError(f"points must be (N, 3), got {points.shape}") + self.points = np.ascontiguousarray(points, dtype=np.float32) + self.colors = np.ascontiguousarray(colors, dtype=np.float32) if colors is not None else None + self.confidence = np.ascontiguousarray(confidence, dtype=np.float32) if confidence is not None else None + self.view_id = np.ascontiguousarray(view_id, dtype=np.int32) if view_id is not None else None + + @property + def is_gaussian(self) -> bool: + return self.raw_data is not None + + @property + def num_points(self) -> int: + if self.points is not None: + return self.points.shape[0] + return 0 + + @staticmethod + def _to_numpy(arr, dtype): + if arr is None: + return None + if hasattr(arr, "numpy"): + arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy() + return np.ascontiguousarray(arr, dtype=dtype) + + def save_to(self, path: str) -> str: + if self.raw_data is not None: + with open(path, "wb") as f: + f.write(self.raw_data) + return path + self.points = self._to_numpy(self.points, np.float32) + self.colors = self._to_numpy(self.colors, np.float32) + self.confidence = self._to_numpy(self.confidence, np.float32) + self.view_id = self._to_numpy(self.view_id, np.int32) + N = self.num_points + header_lines = [ + "ply", + "format ascii 1.0", + f"element vertex {N}", + "property float x", + "property float y", + "property float z", + ] + if self.colors is not None: + header_lines += ["property uchar red", "property uchar green", "property uchar blue"] + if self.confidence is not None: + header_lines.append("property float confidence") + if self.view_id is not None: + header_lines.append("property int view_id") + header_lines.append("end_header") + + with open(path, "w") as f: + f.write("\n".join(header_lines) + "\n") + for i in range(N): + parts = [f"{self.points[i, 0]} {self.points[i, 1]} {self.points[i, 2]}"] + if self.colors is not None: + r, g, b = (self.colors[i] * 255).clip(0, 255).astype(np.uint8) + parts.append(f"{r} {g} {b}") + if self.confidence is not None: + parts.append(f"{self.confidence[i]}") + if self.view_id is not None: + parts.append(f"{int(self.view_id[i])}") + f.write(" ".join(parts) + "\n") + return path diff --git a/comfy_extras/nodes_save_npz.py b/comfy_extras/nodes_save_npz.py new file mode 100644 index 000000000..756a01907 --- /dev/null +++ b/comfy_extras/nodes_save_npz.py @@ -0,0 +1,40 @@ +import os + +import folder_paths +from comfy_api.latest import io +from comfy_api.latest._util.npz_types import NPZ + + +class SaveNPZ(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveNPZ", + display_name="Save NPZ", + category="3d", + is_output_node=True, + inputs=[ + io.Npz.Input("npz"), + io.String.Input("filename_prefix", default="da3_streaming/ComfyUI"), + ], + ) + + @classmethod + def execute(cls, npz: NPZ, filename_prefix: str) -> io.NodeOutput: + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, folder_paths.get_output_directory() + ) + batch_dir = os.path.join(full_output_folder, f"{filename}_{counter:05}") + os.makedirs(batch_dir, exist_ok=True) + filenames = [] + for i, frame_bytes in enumerate(npz.frames): + f = f"frame_{i:06d}.npz" + with open(os.path.join(batch_dir, f), "wb") as fh: + fh.write(frame_bytes) + filenames.append(f) + return io.NodeOutput(ui={"npz_files": [{"folder": os.path.join(subfolder, f"{filename}_{counter:05}"), "count": len(filenames), "type": "output"}]}) + + +NODE_CLASS_MAPPINGS = { + "SaveNPZ": SaveNPZ, +} diff --git a/comfy_extras/nodes_save_ply.py b/comfy_extras/nodes_save_ply.py new file mode 100644 index 000000000..6ee7a67e9 --- /dev/null +++ b/comfy_extras/nodes_save_ply.py @@ -0,0 +1,34 @@ +import os + +import folder_paths +from comfy_api.latest import io +from comfy_api.latest._util.ply_types import PLY + + +class SavePLY(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SavePLY", + display_name="Save PLY", + category="3d", + is_output_node=True, + inputs=[ + io.Ply.Input("ply"), + io.String.Input("filename_prefix", default="pointcloud/ComfyUI"), + ], + ) + + @classmethod + def execute(cls, ply: PLY, filename_prefix: str) -> io.NodeOutput: + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, folder_paths.get_output_directory() + ) + f = f"{filename}_{counter:05}_.ply" + ply.save_to(os.path.join(full_output_folder, f)) + return io.NodeOutput(ui={"pointclouds": [{"filename": f, "subfolder": subfolder, "type": "output"}]}) + + +NODE_CLASS_MAPPINGS = { + "SavePLY": SavePLY, +} diff --git a/cuda_malloc.py b/cuda_malloc.py index f7651981c..f6d2063e9 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -92,7 +92,7 @@ if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync" - else: + elif not args.use_process_isolation: env_var += ",backend:cudaMallocAsync" os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var diff --git a/execution.py b/execution.py index a7791efed..d827ab3d1 100644 --- a/execution.py +++ b/execution.py @@ -1,7 +1,9 @@ import copy +import gc import heapq import inspect import logging +import os import sys import threading import time @@ -41,6 +43,8 @@ from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io +_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = False + class ExecutionResult(Enum): SUCCESS = 0 @@ -261,20 +265,31 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f pre_execute_cb(index) # V3 if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): - # if is just a class, then assign no state, just create clone - if is_class(obj): - type_obj = obj - obj.VALIDATE_CLASS() - class_clone = obj.PREPARE_CLASS_CLONE(v3_data) - # otherwise, use class instance to populate/reuse some fields + # Check for isolated node - skip validation and class cloning + if hasattr(obj, "_pyisolate_extension"): + # Isolated Node: The stub is just a proxy; real validation happens in child process + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) + # Inject hidden inputs so they're available in the isolated child process + inputs.update(v3_data.get("hidden_inputs", {})) + f = getattr(obj, func) + # Standard V3 Node (Existing Logic) + else: - type_obj = type(obj) - type_obj.VALIDATE_CLASS() - class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) - f = make_locked_method_func(type_obj, func, class_clone) - # in case of dynamic inputs, restructure inputs to expected nested dict - if v3_data is not None: - inputs = _io.build_nested_inputs(inputs, v3_data) + # if is just a class, then assign no resources or state, just create clone + if is_class(obj): + type_obj = obj + obj.VALIDATE_CLASS() + class_clone = obj.PREPARE_CLASS_CLONE(v3_data) + # otherwise, use class instance to populate/reuse some fields + else: + type_obj = type(obj) + type_obj.VALIDATE_CLASS() + class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) + f = make_locked_method_func(type_obj, func, class_clone) + # in case of dynamic inputs, restructure inputs to expected nested dict + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) # V1 else: f = getattr(obj, func) @@ -527,7 +542,17 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if args.verbose == "DEBUG": comfy_aimdo.control.analyze() comfy.model_management.reset_cast_buffers() - comfy_aimdo.model_vbar.vbars_reset_watermark_limits() + vbar_lib = getattr(comfy_aimdo.model_vbar, "lib", None) + if vbar_lib is not None: + comfy_aimdo.model_vbar.vbars_reset_watermark_limits() + else: + global _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED + if not _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED: + logging.warning( + "DynamicVRAM backend unavailable for watermark reset; " + "skipping vbar reset for this process." + ) + _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = True if has_pending_tasks: pending_async_nodes[unique_id] = output_data @@ -536,6 +561,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, tasks = [x for x in output_data if isinstance(x, asyncio.Task)] await asyncio.gather(*tasks, return_exceptions=True) unblock() + + # Keep isolation node execution deterministic by default, but allow + # opt-out for diagnostics. + isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes") + if args.use_process_isolation and isolation_sequential: + await await_completion() + return await execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs) + asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: @@ -647,6 +680,46 @@ class PromptExecutor: self.status_messages = [] self.success = True + async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import notify_execution_graph + await notify_execution_graph(class_types) + except Exception: + if fail_loud: + raise + logging.debug("][ EX:notify_execution_graph failed", exc_info=True) + + async def _flush_running_extensions_transport_state_safe(self) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import flush_running_extensions_transport_state + await flush_running_extensions_transport_state() + except Exception: + logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True) + + async def _wait_model_patcher_quiescence_safe( + self, + *, + fail_loud: bool = False, + timeout_ms: int = 120000, + marker: str = "EX:wait_model_patcher_idle", + ) -> None: + if not args.use_process_isolation: + return + try: + from comfy.isolation import wait_for_model_patcher_quiescence + + await wait_for_model_patcher_quiescence( + timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker + ) + except Exception: + if fail_loud: + raise + logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True) + def add_message(self, event, data: dict, broadcast: bool): data = { **data, @@ -688,6 +761,18 @@ class PromptExecutor: asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + if args.use_process_isolation: + # Update RPC event loops for all isolated extensions. + # This is critical for serial workflow execution - each asyncio.run() creates + # a new event loop, and RPC instances must be updated to use it. + try: + from comfy.isolation import update_rpc_event_loops + update_rpc_event_loops() + except ImportError: + pass # Isolation not available + except Exception as e: + logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}") + set_preview_method(extra_data.get("preview_method")) nodes.interrupt_processing(False) @@ -701,6 +786,25 @@ class PromptExecutor: self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): + if args.use_process_isolation: + try: + # Boundary cleanup runs at the start of the next workflow in + # isolation mode, matching non-isolated "next prompt" timing. + self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) + await self._wait_model_patcher_quiescence_safe( + fail_loud=False, + timeout_ms=120000, + marker="EX:boundary_cleanup_wait_idle", + ) + await self._flush_running_extensions_transport_state_safe() + comfy.model_management.unload_all_models() + comfy.model_management.cleanup_models_gc() + comfy.model_management.cleanup_models() + gc.collect() + comfy.model_management.soft_empty_cache() + except Exception: + logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True) + dynamic_prompt = DynamicPrompt(prompt) reset_progress_state(prompt_id, dynamic_prompt) add_progress_handler(WebUIProgressHandler(self.server)) @@ -727,6 +831,18 @@ class PromptExecutor: for node_id in list(execute_outputs): execution_list.add_node(node_id) + if args.use_process_isolation: + pending_class_types = set() + for node_id in execution_list.pendingNodes.keys(): + class_type = dynamic_prompt.get_node(node_id)["class_type"] + pending_class_types.add(class_type) + await self._wait_model_patcher_quiescence_safe( + fail_loud=True, + timeout_ms=120000, + marker="EX:notify_graph_wait_idle", + ) + await self._notify_execution_graph_safe(pending_class_types, fail_loud=True) + while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() if error is not None: @@ -757,6 +873,7 @@ class PromptExecutor: "outputs": ui_outputs, "meta": meta_outputs, } + comfy.model_management.cleanup_models_gc() self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: comfy.model_management.unload_all_models() diff --git a/main.py b/main.py index 8905fd09a..fa6322c2b 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,21 @@ +import os +import sys + +IS_PYISOLATE_CHILD = os.environ.get("PYISOLATE_CHILD") == "1" + +if __name__ == "__main__" and IS_PYISOLATE_CHILD: + del os.environ["PYISOLATE_CHILD"] + IS_PYISOLATE_CHILD = False + +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +if CURRENT_DIR not in sys.path: + sys.path.insert(0, CURRENT_DIR) + +IS_PRIMARY_PROCESS = (not IS_PYISOLATE_CHILD) and __name__ == "__main__" + import comfy.options comfy.options.enable_args_parsing() -import os import importlib.util import shutil import importlib.metadata @@ -20,12 +34,45 @@ from comfy_execution.utils import get_executing_context from comfy_api import feature_flags from app.database.db import init_db, dependencies_available -if __name__ == "__main__": - #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. +import comfy_aimdo.control + +if enables_dynamic_vram(): + if not comfy_aimdo.control.init(): + logging.warning( + "DynamicVRAM requested, but comfy-aimdo failed to initialize early. " + "Will fall back to legacy model loading if device init fails." + ) + +if '--use-process-isolation' in sys.argv: + from comfy.isolation import initialize_proxies + initialize_proxies() + + # Explicitly register the ComfyUI adapter for pyisolate (v1.0 architecture) + try: + import pyisolate + from comfy.isolation.adapter import ComfyUIAdapter + pyisolate.register_adapter(ComfyUIAdapter()) + logging.info("PyIsolate adapter registered: comfyui") + except ImportError: + logging.warning("PyIsolate not installed or version too old for explicit registration") + except Exception as e: + logging.error(f"Failed to register PyIsolate adapter: {e}") + + if not IS_PYISOLATE_CHILD: + if 'PYTORCH_CUDA_ALLOC_CONF' not in os.environ: + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:native' + +if not IS_PYISOLATE_CHILD: + from comfy_execution.progress import get_progress_state + from comfy_execution.utils import get_executing_context + from comfy_api import feature_flags + +if IS_PRIMARY_PROCESS: os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' -setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +if not IS_PYISOLATE_CHILD: + setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) faulthandler.enable(file=sys.stderr, all_threads=False) @@ -91,14 +138,15 @@ if args.enable_manager: def apply_custom_paths(): + from utils import extra_config # Deferred import - spawn re-runs main.py # extra model paths extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): - utils.extra_config.load_extra_path_config(extra_model_paths_config_path) + extra_config.load_extra_path_config(extra_model_paths_config_path) if args.extra_model_paths_config: for config_path in itertools.chain(*args.extra_model_paths_config): - utils.extra_config.load_extra_path_config(config_path) + extra_config.load_extra_path_config(config_path) # --output-directory, --input-directory, --user-directory if args.output_directory: @@ -171,15 +219,17 @@ def execute_prestartup_script(): else: import_message = " (PRESTARTUP FAILED)" logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) - logging.info("") + logging.info("") -apply_custom_paths() -init_mime_types() +if not IS_PYISOLATE_CHILD: + apply_custom_paths() + init_mime_types() -if args.enable_manager: +if args.enable_manager and not IS_PYISOLATE_CHILD: comfyui_manager.prestartup() -execute_prestartup_script() +if not IS_PYISOLATE_CHILD: + execute_prestartup_script() # Main code @@ -190,18 +240,18 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") - import comfy.utils from app.assets.seeder import asset_seeder -import execution -import server -from protocol import BinaryEventTypes -import nodes -import comfy.model_management -import comfyui_version -import app.logger -import hook_breaker_ac10a0 +if not IS_PYISOLATE_CHILD: + import execution + import server + from protocol import BinaryEventTypes + import nodes + import comfy.model_management + import comfyui_version + import app.logger + import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher @@ -417,6 +467,10 @@ def start_comfyui(asyncio_loop=None): asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) + if args.use_process_isolation: + from comfy.isolation import start_isolation_loading_early + start_isolation_loading_early(asyncio_loop) + if args.enable_manager and not args.disable_manager_ui: comfyui_manager.start() @@ -461,12 +515,13 @@ def start_comfyui(asyncio_loop=None): if __name__ == "__main__": # Running directly, just start ComfyUI. logging.info("Python version: {}".format(sys.version)) - logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) - for package in ("comfy-aimdo", "comfy-kitchen"): - try: - logging.info("{} version: {}".format(package, importlib.metadata.version(package))) - except: - pass + if not IS_PYISOLATE_CHILD: + logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) + for package in ("comfy-aimdo", "comfy-kitchen"): + try: + logging.info("{} version: {}".format(package, importlib.metadata.version(package))) + except: + pass if sys.version_info.major == 3 and sys.version_info.minor < 10: logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") diff --git a/nodes.py b/nodes.py index 0ef23b640..c06bdd469 100644 --- a/nodes.py +++ b/nodes.py @@ -1925,6 +1925,7 @@ class ImageInvert: class ImageBatch: SEARCH_ALIASES = ["combine images", "merge images", "stack images"] + ESSENTIALS_CATEGORY = "Image Tools" @classmethod def INPUT_TYPES(s): @@ -2306,6 +2307,27 @@ async def init_external_custom_nodes(): Returns: None """ + whitelist = set() + isolated_module_paths = set() + if args.use_process_isolation: + from pathlib import Path + from comfy.isolation import await_isolation_loading, get_claimed_paths + from comfy.isolation.host_policy import load_host_policy + + # Load Global Host Policy + host_policy = load_host_policy(Path(folder_paths.base_path)) + whitelist_dict = host_policy.get("whitelist", {}) + # Normalize whitelist keys to lowercase for case-insensitive matching + # (matches ComfyUI-Manager's normalization: project.name.strip().lower()) + whitelist = set(k.strip().lower() for k in whitelist_dict.keys()) + logging.info(f"][ Loaded Whitelist: {len(whitelist)} nodes allowed.") + + isolated_specs = await await_isolation_loading() + for spec in isolated_specs: + NODE_CLASS_MAPPINGS.setdefault(spec.node_name, spec.stub_class) + NODE_DISPLAY_NAME_MAPPINGS.setdefault(spec.node_name, spec.display_name) + isolated_module_paths = get_claimed_paths() + base_node_names = set(NODE_CLASS_MAPPINGS.keys()) node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] @@ -2329,6 +2351,16 @@ async def init_external_custom_nodes(): logging.info(f"Blocked by policy: {module_path}") continue + if args.use_process_isolation: + if Path(module_path).resolve() in isolated_module_paths: + continue + + # Tri-State Enforcement: If not Isolated (checked above), MUST be Whitelisted. + # Normalize to lowercase for case-insensitive matching (matches ComfyUI-Manager) + if possible_module.strip().lower() not in whitelist: + logging.warning(f"][ REJECTED: Node '{possible_module}' is blocked by security policy (not whitelisted/isolated).") + continue + time_before = time.perf_counter() success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes") node_import_times.append((time.perf_counter() - time_before, module_path, success)) @@ -2343,6 +2375,14 @@ async def init_external_custom_nodes(): logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) logging.info("") + if args.use_process_isolation: + from comfy.isolation import isolated_node_timings + if isolated_node_timings: + logging.info("\nImport times for isolated custom nodes:") + for timing, path, count in sorted(isolated_node_timings): + logging.info("{:6.1f} seconds: {} ({})".format(timing, path, count)) + logging.info("") + async def init_builtin_extra_nodes(): """ Initializes the built-in extra nodes in ComfyUI. @@ -2415,6 +2455,8 @@ async def init_builtin_extra_nodes(): "nodes_wan.py", "nodes_lotus.py", "nodes_hunyuan3d.py", + "nodes_save_ply.py", + "nodes_save_npz.py", "nodes_primitive.py", "nodes_cfg.py", "nodes_optimalsteps.py", @@ -2435,7 +2477,6 @@ async def init_builtin_extra_nodes(): "nodes_audio_encoder.py", "nodes_rope.py", "nodes_logic.py", - "nodes_resolution.py", "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", @@ -2443,7 +2484,6 @@ async def init_builtin_extra_nodes(): "nodes_zimage.py", "nodes_glsl.py", "nodes_lora_debug.py", - "nodes_textgen.py", "nodes_color.py", "nodes_toolkit.py", "nodes_replacements.py", diff --git a/pyproject.toml b/pyproject.toml index 753b219b3..61fd5d383 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,25 @@ homepage = "https://www.comfy.org/" repository = "https://github.com/comfyanonymous/ComfyUI" documentation = "https://docs.comfy.org/" +[tool.comfy.host] +allow_network = false +writable_paths = ["/dev/shm", "/tmp"] + +[tool.comfy.host.whitelist] +"ComfyUI-Crystools" = "*" +"ComfyUI-Florence2" = "*" +"ComfyUI-GGUF" = "*" +"ComfyUI-KJNodes" = "*" +"ComfyUI-LTXVideo" = "*" +"ComfyUI-Manager" = "*" +"comfyui-depthanythingv2" = "*" +"comfyui-kjnodes" = "*" +"comfyui-videohelpersuite" = "*" +"comfyui_controlnet_aux" = "*" +"rgthree-comfy" = "*" +"was-ns" = "*" +"websocket_image_save.py" = "*" + [tool.ruff] lint.select = [ "N805", # invalid-first-argument-name-for-method diff --git a/requirements.txt b/requirements.txt index 2272d121a..912829f8d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,3 +35,5 @@ pydantic~=2.0 pydantic-settings~=2.0 PyOpenGL glfw + +pyisolate==0.9.2 diff --git a/server.py b/server.py index 76904ebc9..4c4779934 100644 --- a/server.py +++ b/server.py @@ -3,7 +3,6 @@ import sys import asyncio import traceback import time - import nodes import folder_paths import execution @@ -196,6 +195,8 @@ def create_block_external_middleware(): class PromptServer(): def __init__(self, loop): PromptServer.instance = self + if loop is None: + loop = asyncio.get_event_loop() self.user_manager = UserManager() self.model_file_manager = ModelFileManager() diff --git a/tests/isolation/test_client_snapshot.py b/tests/isolation/test_client_snapshot.py new file mode 100644 index 000000000..c6f906813 --- /dev/null +++ b/tests/isolation/test_client_snapshot.py @@ -0,0 +1,122 @@ +"""Tests for pyisolate._internal.client import-time snapshot handling.""" + +import json +import os +import subprocess +import sys +from pathlib import Path + +import pytest + +# Paths needed for subprocess +PYISOLATE_ROOT = str(Path(__file__).parent.parent) +COMFYUI_ROOT = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") + +SCRIPT = """ +import json, sys +import pyisolate._internal.client # noqa: F401 # triggers snapshot logic +print(json.dumps(sys.path[:6])) +""" + + +def _run_client_process(env): + # Ensure subprocess can find pyisolate and ComfyUI + pythonpath_parts = [PYISOLATE_ROOT, COMFYUI_ROOT] + existing = env.get("PYTHONPATH", "") + if existing: + pythonpath_parts.append(existing) + env["PYTHONPATH"] = ":".join(pythonpath_parts) + + result = subprocess.run( # noqa: S603 + [sys.executable, "-c", SCRIPT], + capture_output=True, + text=True, + env=env, + check=True, + ) + stdout = result.stdout.strip().splitlines()[-1] + return json.loads(stdout) + + +@pytest.fixture() +def comfy_module_path(tmp_path): + comfy_root = tmp_path / "ComfyUI" + module_path = comfy_root / "custom_nodes" / "TestNode" + module_path.mkdir(parents=True) + return comfy_root, module_path + + +def test_snapshot_applied_and_comfy_root_prepend(tmp_path, comfy_module_path): + comfy_root, module_path = comfy_module_path + # Must include real ComfyUI path for utils validation to pass + host_paths = [COMFYUI_ROOT, "/host/lib1", "/host/lib2"] + snapshot = { + "sys_path": host_paths, + "sys_executable": sys.executable, + "sys_prefix": sys.prefix, + "environment": {}, + } + snapshot_path = tmp_path / "snapshot.json" + snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8") + + env = os.environ.copy() + env.update( + { + "PYISOLATE_CHILD": "1", + "PYISOLATE_HOST_SNAPSHOT": str(snapshot_path), + "PYISOLATE_MODULE_PATH": str(module_path), + } + ) + + path_prefix = _run_client_process(env) + + # Current client behavior preserves the runtime bootstrap path order and + # keeps the resolved ComfyUI root available for imports. + assert COMFYUI_ROOT in path_prefix + # Module path should not override runtime root selection. + assert str(comfy_root) not in path_prefix + + +def test_missing_snapshot_file_does_not_crash(tmp_path, comfy_module_path): + _, module_path = comfy_module_path + missing_snapshot = tmp_path / "missing.json" + + env = os.environ.copy() + env.update( + { + "PYISOLATE_CHILD": "1", + "PYISOLATE_HOST_SNAPSHOT": str(missing_snapshot), + "PYISOLATE_MODULE_PATH": str(module_path), + } + ) + + # Should not raise even though snapshot path is missing + paths = _run_client_process(env) + assert len(paths) > 0 + + +def test_no_comfy_root_when_module_path_absent(tmp_path): + # Must include real ComfyUI path for utils validation to pass + host_paths = [COMFYUI_ROOT, "/alpha", "/beta"] + snapshot = { + "sys_path": host_paths, + "sys_executable": sys.executable, + "sys_prefix": sys.prefix, + "environment": {}, + } + snapshot_path = tmp_path / "snapshot.json" + snapshot_path.write_text(json.dumps(snapshot), encoding="utf-8") + + env = os.environ.copy() + env.update( + { + "PYISOLATE_CHILD": "1", + "PYISOLATE_HOST_SNAPSHOT": str(snapshot_path), + } + ) + + paths = _run_client_process(env) + # Runtime path bootstrap keeps ComfyUI importability regardless of host + # snapshot extras. + assert COMFYUI_ROOT in paths + assert "/alpha" not in paths and "/beta" not in paths diff --git a/tests/isolation/test_cuda_wheels_and_env_flags.py b/tests/isolation/test_cuda_wheels_and_env_flags.py new file mode 100644 index 000000000..9f0813f9f --- /dev/null +++ b/tests/isolation/test_cuda_wheels_and_env_flags.py @@ -0,0 +1,302 @@ +"""Synthetic integration coverage for manifest plumbing and env flags. + +These tests do not perform a real wheel install or a real ComfyUI E2E run. +""" + +import asyncio +import logging +import os +import sys +from types import SimpleNamespace + +import pytest + +from comfy.isolation import runtime_helpers +from comfy.isolation.extension_loader import ExtensionLoadError, load_isolated_node +from comfy.isolation.extension_wrapper import ComfyNodeExtension +from comfy.isolation.model_patcher_proxy_utils import maybe_wrap_model_for_isolation + + +class _DummyExtension: + def __init__(self) -> None: + self.name = "demo-extension" + + async def stop(self) -> None: + return None + + +def _write_manifest(node_dir, manifest_text: str) -> None: + (node_dir / "pyproject.toml").write_text(manifest_text, encoding="utf-8") + + +def test_load_isolated_node_passes_normalized_cuda_wheels_config(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["flash-attn>=1.0", "sageattention==0.1"] + +[tool.comfy.isolation] +can_isolate = true +share_torch = true + +[tool.comfy.isolation.cuda_wheels] +index_url = "https://example.invalid/cuda-wheels" +packages = ["flash_attn", "sageattention"] + +[tool.comfy.isolation.cuda_wheels.package_map] +flash_attn = "flash-attn-special" +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr( + "comfy.isolation.extension_loader.pyisolate.ExtensionManager", DummyManager + ) + monkeypatch.setattr( + "comfy.isolation.extension_loader.load_host_policy", + lambda base_path: { + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr( + "comfy.isolation.extension_loader.is_cache_valid", lambda *args, **kwargs: True + ) + monkeypatch.setattr( + "comfy.isolation.extension_loader.load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + specs = asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + assert len(specs) == 1 + assert captured["cuda_wheels"] == { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn", "sageattention"], + "package_map": {"flash-attn": "flash-attn-special"}, + } + + +def test_load_isolated_node_rejects_undeclared_cuda_wheel_dependency( + tmp_path, monkeypatch +): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["numpy>=1.0"] + +[tool.comfy.isolation] +can_isolate = true + +[tool.comfy.isolation.cuda_wheels] +index_url = "https://example.invalid/cuda-wheels" +packages = ["flash-attn"] +""".strip(), + ) + + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + with pytest.raises(ExtensionLoadError, match="undeclared dependencies"): + asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + +def test_load_isolated_node_omits_cuda_wheels_when_not_configured(tmp_path, monkeypatch): + node_dir = tmp_path / "node" + node_dir.mkdir() + manifest_path = node_dir / "pyproject.toml" + _write_manifest( + node_dir, + """ +[project] +name = "demo-node" +dependencies = ["numpy>=1.0"] + +[tool.comfy.isolation] +can_isolate = true +""".strip(), + ) + + captured: dict[str, object] = {} + + class DummyManager: + def __init__(self, *args, **kwargs) -> None: + return None + + def load_extension(self, config): + captured.update(config) + return _DummyExtension() + + monkeypatch.setattr( + "comfy.isolation.extension_loader.pyisolate.ExtensionManager", DummyManager + ) + monkeypatch.setattr( + "comfy.isolation.extension_loader.load_host_policy", + lambda base_path: { + "allow_network": False, + "writable_paths": [], + "readonly_paths": [], + }, + ) + monkeypatch.setattr( + "comfy.isolation.extension_loader.is_cache_valid", lambda *args, **kwargs: True + ) + monkeypatch.setattr( + "comfy.isolation.extension_loader.load_from_cache", + lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}}, + ) + monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path))) + + asyncio.run( + load_isolated_node( + node_dir, + manifest_path, + logging.getLogger("test"), + lambda *args, **kwargs: object, + tmp_path / "venvs", + [], + ) + ) + + assert "cuda_wheels" not in captured + + +def test_maybe_wrap_model_for_isolation_uses_runtime_flag(monkeypatch): + class DummyRegistry: + def register(self, model): + return "model-123" + + class DummyProxy: + def __init__(self, model_id, registry, manage_lifecycle): + self.model_id = model_id + self.registry = registry + self.manage_lifecycle = manage_lifecycle + + monkeypatch.setattr("comfy.isolation.model_patcher_proxy_utils.args.use_process_isolation", True) + monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False) + monkeypatch.delenv("PYISOLATE_CHILD", raising=False) + monkeypatch.setitem( + sys.modules, + "comfy.isolation.model_patcher_proxy_registry", + SimpleNamespace(ModelPatcherRegistry=DummyRegistry), + ) + monkeypatch.setitem( + sys.modules, + "comfy.isolation.model_patcher_proxy", + SimpleNamespace(ModelPatcherProxy=DummyProxy), + ) + + wrapped = maybe_wrap_model_for_isolation(object()) + + assert isinstance(wrapped, DummyProxy) + assert wrapped.model_id == "model-123" + assert wrapped.manage_lifecycle is True + + +def test_flush_transport_state_uses_child_env_without_legacy_flag(monkeypatch): + monkeypatch.setenv("PYISOLATE_CHILD", "1") + monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False) + monkeypatch.setattr( + "comfy.isolation.extension_wrapper._flush_tensor_transport_state", + lambda marker: 3, + ) + monkeypatch.setitem( + sys.modules, + "comfy.isolation.model_patcher_proxy_registry", + SimpleNamespace( + ModelPatcherRegistry=lambda: SimpleNamespace( + sweep_pending_cleanup=lambda: 0 + ) + ), + ) + + flushed = asyncio.run( + ComfyNodeExtension.flush_transport_state(SimpleNamespace(name="demo")) + ) + + assert flushed == 3 + + +def test_build_stub_class_relieves_host_vram_without_legacy_flag(monkeypatch): + import comfy.isolation as isolation_pkg + + relieve_calls: list[str] = [] + + async def deserialize_from_isolation(result, extension): + return result + + monkeypatch.delenv("PYISOLATE_CHILD", raising=False) + monkeypatch.delenv("PYISOLATE_ISOLATION_ACTIVE", raising=False) + monkeypatch.setattr( + runtime_helpers, "_relieve_host_vram_pressure", lambda marker, logger: relieve_calls.append(marker) + ) + monkeypatch.setattr(runtime_helpers, "scan_shm_forensics", lambda *args, **kwargs: None) + monkeypatch.setattr(isolation_pkg, "_RUNNING_EXTENSIONS", {}, raising=False) + monkeypatch.setitem( + sys.modules, + "pyisolate._internal.model_serialization", + SimpleNamespace( + serialize_for_isolation=lambda payload: payload, + deserialize_from_isolation=deserialize_from_isolation, + ), + ) + + class DummyExtension: + name = "demo-extension" + module_path = os.getcwd() + + async def execute_node(self, node_name, **inputs): + return inputs + + stub_cls = runtime_helpers.build_stub_class( + "DemoNode", + {"input_types": {}}, + DummyExtension(), + {}, + logging.getLogger("test"), + ) + + result = asyncio.run( + getattr(stub_cls, "_pyisolate_execute")(SimpleNamespace(), value=1) + ) + + assert relieve_calls == ["RUNTIME:pre_execute"] + assert result == {"value": 1} diff --git a/tests/isolation/test_folder_paths_proxy.py b/tests/isolation/test_folder_paths_proxy.py new file mode 100644 index 000000000..23585647b --- /dev/null +++ b/tests/isolation/test_folder_paths_proxy.py @@ -0,0 +1,111 @@ +"""Unit tests for FolderPathsProxy.""" + +import pytest +from pathlib import Path + +from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + + +class TestFolderPathsProxy: + """Test FolderPathsProxy methods.""" + + @pytest.fixture + def proxy(self): + """Create a FolderPathsProxy instance for testing.""" + return FolderPathsProxy() + + def test_get_temp_directory_returns_string(self, proxy): + """Verify get_temp_directory returns a non-empty string.""" + result = proxy.get_temp_directory() + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Temp directory path is empty" + + def test_get_temp_directory_returns_absolute_path(self, proxy): + """Verify get_temp_directory returns an absolute path.""" + result = proxy.get_temp_directory() + path = Path(result) + assert path.is_absolute(), f"Path is not absolute: {result}" + + def test_get_input_directory_returns_string(self, proxy): + """Verify get_input_directory returns a non-empty string.""" + result = proxy.get_input_directory() + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Input directory path is empty" + + def test_get_input_directory_returns_absolute_path(self, proxy): + """Verify get_input_directory returns an absolute path.""" + result = proxy.get_input_directory() + path = Path(result) + assert path.is_absolute(), f"Path is not absolute: {result}" + + def test_get_annotated_filepath_plain_name(self, proxy): + """Verify get_annotated_filepath works with plain filename.""" + result = proxy.get_annotated_filepath("test.png") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.png" in result, f"Filename not in result: {result}" + + def test_get_annotated_filepath_with_output_annotation(self, proxy): + """Verify get_annotated_filepath handles [output] annotation.""" + result = proxy.get_annotated_filepath("test.png[output]") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.pn" in result, f"Filename base not in result: {result}" + # Should resolve to output directory + assert "output" in result.lower() or Path(result).parent.name == "output" + + def test_get_annotated_filepath_with_input_annotation(self, proxy): + """Verify get_annotated_filepath handles [input] annotation.""" + result = proxy.get_annotated_filepath("test.png[input]") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.pn" in result, f"Filename base not in result: {result}" + + def test_get_annotated_filepath_with_temp_annotation(self, proxy): + """Verify get_annotated_filepath handles [temp] annotation.""" + result = proxy.get_annotated_filepath("test.png[temp]") + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "test.pn" in result, f"Filename base not in result: {result}" + + def test_exists_annotated_filepath_returns_bool(self, proxy): + """Verify exists_annotated_filepath returns a boolean.""" + result = proxy.exists_annotated_filepath("nonexistent.png") + assert isinstance(result, bool), f"Expected bool, got {type(result)}" + + def test_exists_annotated_filepath_nonexistent_file(self, proxy): + """Verify exists_annotated_filepath returns False for nonexistent file.""" + result = proxy.exists_annotated_filepath("definitely_does_not_exist_12345.png") + assert result is False, "Expected False for nonexistent file" + + def test_exists_annotated_filepath_with_annotation(self, proxy): + """Verify exists_annotated_filepath works with annotation suffix.""" + # Even for nonexistent files, should return bool without error + result = proxy.exists_annotated_filepath("test.png[output]") + assert isinstance(result, bool), f"Expected bool, got {type(result)}" + + def test_models_dir_property_returns_string(self, proxy): + """Verify models_dir property returns valid path string.""" + result = proxy.models_dir + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Models directory path is empty" + + def test_models_dir_is_absolute_path(self, proxy): + """Verify models_dir returns an absolute path.""" + result = proxy.models_dir + path = Path(result) + assert path.is_absolute(), f"Path is not absolute: {result}" + + def test_add_model_folder_path_runs_without_error(self, proxy): + """Verify add_model_folder_path executes without raising.""" + test_path = "/tmp/test_models_florence2" + # Should not raise + proxy.add_model_folder_path("TEST_FLORENCE2", test_path) + + def test_get_folder_paths_returns_list(self, proxy): + """Verify get_folder_paths returns a list.""" + # Use known folder type that should exist + result = proxy.get_folder_paths("checkpoints") + assert isinstance(result, list), f"Expected list, got {type(result)}" + + def test_get_folder_paths_checkpoints_not_empty(self, proxy): + """Verify checkpoints folder paths list is not empty.""" + result = proxy.get_folder_paths("checkpoints") + # Should have at least one checkpoint path registered + assert len(result) > 0, "Checkpoints folder paths is empty" diff --git a/tests/isolation/test_host_policy.py b/tests/isolation/test_host_policy.py new file mode 100644 index 000000000..b6097132b --- /dev/null +++ b/tests/isolation/test_host_policy.py @@ -0,0 +1,72 @@ +from pathlib import Path + + +def _write_pyproject(path: Path, content: str) -> None: + path.write_text(content, encoding="utf-8") + + +def test_load_host_policy_defaults_when_pyproject_missing(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + policy = load_host_policy(tmp_path) + + assert policy["allow_network"] == DEFAULT_POLICY["allow_network"] + assert policy["writable_paths"] == DEFAULT_POLICY["writable_paths"] + assert policy["readonly_paths"] == DEFAULT_POLICY["readonly_paths"] + assert policy["whitelist"] == DEFAULT_POLICY["whitelist"] + + +def test_load_host_policy_defaults_when_section_missing(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[project] +name = "ComfyUI" +""".strip(), + ) + + policy = load_host_policy(tmp_path) + assert policy["allow_network"] == DEFAULT_POLICY["allow_network"] + assert policy["whitelist"] == {} + + +def test_load_host_policy_reads_values(tmp_path): + from comfy.isolation.host_policy import load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +allow_network = true +writable_paths = ["/tmp/a", "/tmp/b"] +readonly_paths = ["/opt/readonly"] + +[tool.comfy.host.whitelist] +ExampleNode = "*" +""".strip(), + ) + + policy = load_host_policy(tmp_path) + assert policy["allow_network"] is True + assert policy["writable_paths"] == ["/tmp/a", "/tmp/b"] + assert policy["readonly_paths"] == ["/opt/readonly"] + assert policy["whitelist"] == {"ExampleNode": "*"} + + +def test_load_host_policy_ignores_invalid_whitelist_type(tmp_path): + from comfy.isolation.host_policy import DEFAULT_POLICY, load_host_policy + + _write_pyproject( + tmp_path / "pyproject.toml", + """ +[tool.comfy.host] +allow_network = true +whitelist = ["bad"] +""".strip(), + ) + + policy = load_host_policy(tmp_path) + assert policy["allow_network"] is True + assert policy["whitelist"] == DEFAULT_POLICY["whitelist"] diff --git a/tests/isolation/test_init.py b/tests/isolation/test_init.py new file mode 100644 index 000000000..d9dfeb1e6 --- /dev/null +++ b/tests/isolation/test_init.py @@ -0,0 +1,56 @@ +"""Unit tests for PyIsolate isolation system initialization.""" + + + +def test_log_prefix(): + """Verify LOG_PREFIX constant is correctly defined.""" + from comfy.isolation import LOG_PREFIX + assert LOG_PREFIX == "][" + assert isinstance(LOG_PREFIX, str) + + +def test_module_initialization(): + """Verify module initializes without errors.""" + import comfy.isolation + assert hasattr(comfy.isolation, 'LOG_PREFIX') + assert hasattr(comfy.isolation, 'initialize_proxies') + + +class TestInitializeProxies: + def test_initialize_proxies_runs_without_error(self): + from comfy.isolation import initialize_proxies + initialize_proxies() + + def test_initialize_proxies_registers_folder_paths_proxy(self): + from comfy.isolation import initialize_proxies + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + initialize_proxies() + proxy = FolderPathsProxy() + assert proxy is not None + assert hasattr(proxy, "get_temp_directory") + + def test_initialize_proxies_registers_model_management_proxy(self): + from comfy.isolation import initialize_proxies + from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + initialize_proxies() + proxy = ModelManagementProxy() + assert proxy is not None + assert hasattr(proxy, "get_torch_device") + + def test_initialize_proxies_can_be_called_multiple_times(self): + from comfy.isolation import initialize_proxies + initialize_proxies() + initialize_proxies() + initialize_proxies() + + def test_dev_proxies_accessible_when_dev_mode(self, monkeypatch): + """Verify dev mode does not break core proxy initialization.""" + monkeypatch.setenv("PYISOLATE_DEV", "1") + from comfy.isolation import initialize_proxies + from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy + from comfy.isolation.proxies.utils_proxy import UtilsProxy + initialize_proxies() + folder_proxy = FolderPathsProxy() + utils_proxy = UtilsProxy() + assert folder_proxy is not None + assert utils_proxy is not None diff --git a/tests/isolation/test_manifest_loader_cache.py b/tests/isolation/test_manifest_loader_cache.py new file mode 100644 index 000000000..ebee43b7e --- /dev/null +++ b/tests/isolation/test_manifest_loader_cache.py @@ -0,0 +1,434 @@ +""" +Unit tests for manifest_loader.py cache functions. + +Phase 1 tests verify: +1. Cache miss on first run (no cache exists) +2. Cache hit when nothing changes +3. Invalidation on .py file touch +4. Invalidation on manifest change +5. Cache location correctness (in venv_root, NOT in custom_nodes) +6. Corrupt cache handling (graceful failure) + +These tests verify the cache implementation is correct BEFORE it's activated +in extension_loader.py (Phase 2). +""" + +from __future__ import annotations + +import json +import sys +import time +from pathlib import Path +from unittest import mock + + + +class TestComputeCacheKey: + """Tests for compute_cache_key() function.""" + + def test_key_includes_manifest_content(self, tmp_path: Path) -> None: + """Cache key changes when manifest content changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + + # Initial manifest + manifest.write_text("isolated: true\ndependencies: []\n") + key1 = compute_cache_key(node_dir, manifest) + + # Modified manifest + manifest.write_text("isolated: true\ndependencies: [numpy]\n") + key2 = compute_cache_key(node_dir, manifest) + + assert key1 != key2, "Key should change when manifest content changes" + + def test_key_includes_py_file_mtime(self, tmp_path: Path) -> None: + """Cache key changes when any .py file is touched.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + py_file = node_dir / "nodes.py" + py_file.write_text("# test code") + + key1 = compute_cache_key(node_dir, manifest) + + # Wait a moment to ensure mtime changes + time.sleep(0.01) + py_file.write_text("# modified code") + + key2 = compute_cache_key(node_dir, manifest) + + assert key1 != key2, "Key should change when .py file mtime changes" + + def test_key_includes_python_version(self, tmp_path: Path) -> None: + """Cache key changes when Python version changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + key1 = compute_cache_key(node_dir, manifest) + + # Mock different Python version + with mock.patch.object(sys, "version", "3.99.0 (fake)"): + key2 = compute_cache_key(node_dir, manifest) + + assert key1 != key2, "Key should change when Python version changes" + + def test_key_includes_pyisolate_version(self, tmp_path: Path) -> None: + """Cache key changes when PyIsolate version changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + key1 = compute_cache_key(node_dir, manifest) + + # Mock different pyisolate version + with mock.patch.dict(sys.modules, {"pyisolate": mock.MagicMock(__version__="99.99.99")}): + # Need to reimport to pick up the mock + import importlib + from comfy.isolation import manifest_loader + importlib.reload(manifest_loader) + key2 = manifest_loader.compute_cache_key(node_dir, manifest) + + # Keys should be different (though the mock approach is tricky) + # At minimum, verify key is a valid hex string + assert len(key1) == 16, "Key should be 16 hex characters" + assert all(c in "0123456789abcdef" for c in key1), "Key should be hex" + assert len(key2) == 16, "Key should be 16 hex characters" + assert all(c in "0123456789abcdef" for c in key2), "Key should be hex" + + def test_key_excludes_pycache(self, tmp_path: Path) -> None: + """Cache key ignores __pycache__ directory changes.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + + py_file = node_dir / "nodes.py" + py_file.write_text("# test code") + + key1 = compute_cache_key(node_dir, manifest) + + # Add __pycache__ file + pycache = node_dir / "__pycache__" + pycache.mkdir() + (pycache / "nodes.cpython-310.pyc").write_bytes(b"compiled") + + key2 = compute_cache_key(node_dir, manifest) + + assert key1 == key2, "Key should NOT change when __pycache__ modified" + + def test_key_is_deterministic(self, tmp_path: Path) -> None: + """Same inputs produce same key.""" + from comfy.isolation.manifest_loader import compute_cache_key + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + (node_dir / "nodes.py").write_text("# code") + + key1 = compute_cache_key(node_dir, manifest) + key2 = compute_cache_key(node_dir, manifest) + + assert key1 == key2, "Key should be deterministic" + + +class TestGetCachePath: + """Tests for get_cache_path() function.""" + + def test_returns_correct_paths(self, tmp_path: Path) -> None: + """Cache paths are in venv_root, not in node_dir.""" + from comfy.isolation.manifest_loader import get_cache_path + + node_dir = tmp_path / "custom_nodes" / "MyNode" + venv_root = tmp_path / ".pyisolate_venvs" + + key_file, data_file = get_cache_path(node_dir, venv_root) + + assert key_file == venv_root / "MyNode" / "cache" / "cache_key" + assert data_file == venv_root / "MyNode" / "cache" / "node_info.json" + + def test_cache_not_in_custom_nodes(self, tmp_path: Path) -> None: + """Verify cache is NOT stored in custom_nodes directory.""" + from comfy.isolation.manifest_loader import get_cache_path + + node_dir = tmp_path / "custom_nodes" / "MyNode" + venv_root = tmp_path / ".pyisolate_venvs" + + key_file, data_file = get_cache_path(node_dir, venv_root) + + # Neither path should be under node_dir + assert not str(key_file).startswith(str(node_dir)) + assert not str(data_file).startswith(str(node_dir)) + + +class TestIsCacheValid: + """Tests for is_cache_valid() function.""" + + def test_false_when_no_cache_exists(self, tmp_path: Path) -> None: + """Returns False when cache files don't exist.""" + from comfy.isolation.manifest_loader import is_cache_valid + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + assert is_cache_valid(node_dir, manifest, venv_root) is False + + def test_true_when_cache_matches(self, tmp_path: Path) -> None: + """Returns True when cache key matches current state.""" + from comfy.isolation.manifest_loader import ( + compute_cache_key, + get_cache_path, + is_cache_valid, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + (node_dir / "nodes.py").write_text("# code") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create valid cache + cache_key = compute_cache_key(node_dir, manifest) + key_file, data_file = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text(cache_key) + data_file.write_text("{}") + + assert is_cache_valid(node_dir, manifest, venv_root) is True + + def test_false_when_key_mismatch(self, tmp_path: Path) -> None: + """Returns False when stored key doesn't match current state.""" + from comfy.isolation.manifest_loader import get_cache_path, is_cache_valid + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create cache with wrong key + key_file, data_file = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text("wrong_key_12345") + data_file.write_text("{}") + + assert is_cache_valid(node_dir, manifest, venv_root) is False + + def test_false_when_data_file_missing(self, tmp_path: Path) -> None: + """Returns False when node_info.json is missing.""" + from comfy.isolation.manifest_loader import ( + compute_cache_key, + get_cache_path, + is_cache_valid, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create only key file, not data file + cache_key = compute_cache_key(node_dir, manifest) + key_file, _ = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text(cache_key) + + assert is_cache_valid(node_dir, manifest, venv_root) is False + + def test_invalidation_on_py_change(self, tmp_path: Path) -> None: + """Cache invalidates when .py file is modified.""" + from comfy.isolation.manifest_loader import ( + compute_cache_key, + get_cache_path, + is_cache_valid, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + py_file = node_dir / "nodes.py" + py_file.write_text("# original") + venv_root = tmp_path / ".pyisolate_venvs" + + # Create valid cache + cache_key = compute_cache_key(node_dir, manifest) + key_file, data_file = get_cache_path(node_dir, venv_root) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text(cache_key) + data_file.write_text("{}") + + # Verify cache is valid initially + assert is_cache_valid(node_dir, manifest, venv_root) is True + + # Modify .py file + time.sleep(0.01) # Ensure mtime changes + py_file.write_text("# modified") + + # Cache should now be invalid + assert is_cache_valid(node_dir, manifest, venv_root) is False + + +class TestLoadFromCache: + """Tests for load_from_cache() function.""" + + def test_returns_none_when_no_cache(self, tmp_path: Path) -> None: + """Returns None when cache doesn't exist.""" + from comfy.isolation.manifest_loader import load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + assert load_from_cache(node_dir, venv_root) is None + + def test_returns_data_when_valid(self, tmp_path: Path) -> None: + """Returns cached data when file exists and is valid JSON.""" + from comfy.isolation.manifest_loader import get_cache_path, load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + test_data = {"TestNode": {"inputs": [], "outputs": []}} + + _, data_file = get_cache_path(node_dir, venv_root) + data_file.parent.mkdir(parents=True, exist_ok=True) + data_file.write_text(json.dumps(test_data)) + + result = load_from_cache(node_dir, venv_root) + assert result == test_data + + def test_returns_none_on_corrupt_json(self, tmp_path: Path) -> None: + """Returns None when JSON is corrupt.""" + from comfy.isolation.manifest_loader import get_cache_path, load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + _, data_file = get_cache_path(node_dir, venv_root) + data_file.parent.mkdir(parents=True, exist_ok=True) + data_file.write_text("{ corrupt json }") + + assert load_from_cache(node_dir, venv_root) is None + + def test_returns_none_on_invalid_structure(self, tmp_path: Path) -> None: + """Returns None when data is not a dict.""" + from comfy.isolation.manifest_loader import get_cache_path, load_from_cache + + node_dir = tmp_path / "test_node" + venv_root = tmp_path / ".pyisolate_venvs" + + _, data_file = get_cache_path(node_dir, venv_root) + data_file.parent.mkdir(parents=True, exist_ok=True) + data_file.write_text("[1, 2, 3]") # Array, not dict + + assert load_from_cache(node_dir, venv_root) is None + + +class TestSaveToCache: + """Tests for save_to_cache() function.""" + + def test_creates_cache_directory(self, tmp_path: Path) -> None: + """Creates cache directory if it doesn't exist.""" + from comfy.isolation.manifest_loader import get_cache_path, save_to_cache + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + save_to_cache(node_dir, venv_root, {"TestNode": {}}, manifest) + + key_file, data_file = get_cache_path(node_dir, venv_root) + assert key_file.parent.exists() + + def test_writes_both_files(self, tmp_path: Path) -> None: + """Writes both cache_key and node_info.json.""" + from comfy.isolation.manifest_loader import get_cache_path, save_to_cache + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + save_to_cache(node_dir, venv_root, {"TestNode": {"key": "value"}}, manifest) + + key_file, data_file = get_cache_path(node_dir, venv_root) + assert key_file.exists() + assert data_file.exists() + + def test_data_is_valid_json(self, tmp_path: Path) -> None: + """Written data can be parsed as JSON.""" + from comfy.isolation.manifest_loader import get_cache_path, save_to_cache + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + test_data = {"TestNode": {"inputs": ["IMAGE"], "outputs": ["IMAGE"]}} + save_to_cache(node_dir, venv_root, test_data, manifest) + + _, data_file = get_cache_path(node_dir, venv_root) + loaded = json.loads(data_file.read_text()) + assert loaded == test_data + + def test_roundtrip_with_validation(self, tmp_path: Path) -> None: + """Saved cache is immediately valid.""" + from comfy.isolation.manifest_loader import ( + is_cache_valid, + load_from_cache, + save_to_cache, + ) + + node_dir = tmp_path / "test_node" + node_dir.mkdir() + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + (node_dir / "nodes.py").write_text("# code") + venv_root = tmp_path / ".pyisolate_venvs" + + test_data = {"TestNode": {"foo": "bar"}} + save_to_cache(node_dir, venv_root, test_data, manifest) + + assert is_cache_valid(node_dir, manifest, venv_root) is True + assert load_from_cache(node_dir, venv_root) == test_data + + def test_cache_not_in_custom_nodes(self, tmp_path: Path) -> None: + """Verify no files written to custom_nodes directory.""" + from comfy.isolation.manifest_loader import save_to_cache + + node_dir = tmp_path / "custom_nodes" / "MyNode" + node_dir.mkdir(parents=True) + manifest = node_dir / "pyisolate.yaml" + manifest.write_text("isolated: true\n") + venv_root = tmp_path / ".pyisolate_venvs" + + save_to_cache(node_dir, venv_root, {"TestNode": {}}, manifest) + + # Check nothing was created under node_dir + for item in node_dir.iterdir(): + assert item.name == "pyisolate.yaml", f"Unexpected file in node_dir: {item}" diff --git a/tests/isolation/test_model_management_proxy.py b/tests/isolation/test_model_management_proxy.py new file mode 100644 index 000000000..3a03bd54d --- /dev/null +++ b/tests/isolation/test_model_management_proxy.py @@ -0,0 +1,50 @@ +"""Unit tests for ModelManagementProxy.""" + +import pytest +import torch + +from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy + + +class TestModelManagementProxy: + """Test ModelManagementProxy methods.""" + + @pytest.fixture + def proxy(self): + """Create a ModelManagementProxy instance for testing.""" + return ModelManagementProxy() + + def test_get_torch_device_returns_device(self, proxy): + """Verify get_torch_device returns a torch.device object.""" + result = proxy.get_torch_device() + assert isinstance(result, torch.device), f"Expected torch.device, got {type(result)}" + + def test_get_torch_device_is_valid(self, proxy): + """Verify get_torch_device returns a valid device (cpu or cuda).""" + result = proxy.get_torch_device() + assert result.type in ("cpu", "cuda"), f"Unexpected device type: {result.type}" + + def test_get_torch_device_name_returns_string(self, proxy): + """Verify get_torch_device_name returns a non-empty string.""" + device = proxy.get_torch_device() + result = proxy.get_torch_device_name(device) + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert len(result) > 0, "Device name is empty" + + def test_get_torch_device_name_with_cpu(self, proxy): + """Verify get_torch_device_name works with CPU device.""" + cpu_device = torch.device("cpu") + result = proxy.get_torch_device_name(cpu_device) + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert "cpu" in result.lower(), f"Expected 'cpu' in device name, got: {result}" + + def test_get_torch_device_name_with_cuda_if_available(self, proxy): + """Verify get_torch_device_name works with CUDA device if available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + cuda_device = torch.device("cuda:0") + result = proxy.get_torch_device_name(cuda_device) + assert isinstance(result, str), f"Expected str, got {type(result)}" + # Should contain device identifier + assert len(result) > 0, "CUDA device name is empty" diff --git a/tests/isolation/test_path_helpers.py b/tests/isolation/test_path_helpers.py new file mode 100644 index 000000000..af96f1fe0 --- /dev/null +++ b/tests/isolation/test_path_helpers.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path + +import pytest + +from pyisolate.path_helpers import build_child_sys_path, serialize_host_snapshot + + +def test_serialize_host_snapshot_includes_expected_keys(tmp_path: Path, monkeypatch) -> None: + output = tmp_path / "snapshot.json" + monkeypatch.setenv("EXTRA_FLAG", "1") + snapshot = serialize_host_snapshot(output_path=output, extra_env_keys=["EXTRA_FLAG"]) + + assert "sys_path" in snapshot + assert "sys_executable" in snapshot + assert "sys_prefix" in snapshot + assert "environment" in snapshot + assert output.exists() + assert snapshot["environment"].get("EXTRA_FLAG") == "1" + + persisted = json.loads(output.read_text(encoding="utf-8")) + assert persisted["sys_path"] == snapshot["sys_path"] + + +def test_build_child_sys_path_preserves_host_order() -> None: + host_paths = ["/host/root", "/host/site-packages"] + extra_paths = ["/node/.venv/lib/python3.12/site-packages"] + result = build_child_sys_path(host_paths, extra_paths, preferred_root=None) + assert result == host_paths + extra_paths + + +def test_build_child_sys_path_inserts_comfy_root_when_missing() -> None: + host_paths = ["/host/site-packages"] + comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") + extra_paths: list[str] = [] + result = build_child_sys_path(host_paths, extra_paths, preferred_root=comfy_root) + assert result[0] == comfy_root + assert result[1:] == host_paths + + +def test_build_child_sys_path_deduplicates_entries(tmp_path: Path) -> None: + path_a = str(tmp_path / "a") + path_b = str(tmp_path / "b") + host_paths = [path_a, path_b] + extra_paths = [path_a, path_b, str(tmp_path / "c")] + result = build_child_sys_path(host_paths, extra_paths) + assert result == [path_a, path_b, str(tmp_path / "c")] + + +def test_build_child_sys_path_skips_duplicate_comfy_root() -> None: + comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") + host_paths = [comfy_root, "/host/other"] + result = build_child_sys_path(host_paths, extra_paths=[], preferred_root=comfy_root) + assert result == host_paths + + +def test_child_import_succeeds_after_path_unification(tmp_path: Path, monkeypatch) -> None: + host_root = tmp_path / "host" + utils_pkg = host_root / "utils" + app_pkg = host_root / "app" + utils_pkg.mkdir(parents=True) + app_pkg.mkdir(parents=True) + + (utils_pkg / "__init__.py").write_text("from . import install_util\n", encoding="utf-8") + (utils_pkg / "install_util.py").write_text("VALUE = 'hello'\n", encoding="utf-8") + (app_pkg / "__init__.py").write_text("", encoding="utf-8") + (app_pkg / "frontend_management.py").write_text( + "from utils import install_util\nVALUE = install_util.VALUE\n", + encoding="utf-8", + ) + + child_only = tmp_path / "child_only" + child_only.mkdir() + + target_module = "app.frontend_management" + for name in [n for n in list(sys.modules) if n.startswith("app") or n.startswith("utils")]: + sys.modules.pop(name) + + monkeypatch.setattr(sys, "path", [str(child_only)]) + with pytest.raises(ModuleNotFoundError): + __import__(target_module) + + for name in [n for n in list(sys.modules) if n.startswith("app") or n.startswith("utils")]: + sys.modules.pop(name) + + unified = build_child_sys_path([], [], preferred_root=str(host_root)) + monkeypatch.setattr(sys, "path", unified) + module = __import__(target_module, fromlist=["VALUE"]) + assert module.VALUE == "hello" diff --git a/tests/test_adapter.py b/tests/test_adapter.py new file mode 100644 index 000000000..feaa62549 --- /dev/null +++ b/tests/test_adapter.py @@ -0,0 +1,51 @@ +import os +import sys +from pathlib import Path + +repo_root = Path(__file__).resolve().parents[1] +pyisolate_root = repo_root.parent / "pyisolate" +if pyisolate_root.exists(): + sys.path.insert(0, str(pyisolate_root)) + +from comfy.isolation.adapter import ComfyUIAdapter +from pyisolate._internal.serialization_registry import SerializerRegistry + + +def test_identifier(): + adapter = ComfyUIAdapter() + assert adapter.identifier == "comfyui" + + +def test_get_path_config_valid(): + adapter = ComfyUIAdapter() + path = os.path.join("/opt", "ComfyUI", "custom_nodes", "demo") + cfg = adapter.get_path_config(path) + assert cfg is not None + assert cfg["preferred_root"].endswith("ComfyUI") + assert "custom_nodes" in cfg["additional_paths"][0] + + +def test_get_path_config_invalid(): + adapter = ComfyUIAdapter() + assert adapter.get_path_config("/random/path") is None + + +def test_provide_rpc_services(): + adapter = ComfyUIAdapter() + services = adapter.provide_rpc_services() + names = {s.__name__ for s in services} + assert "PromptServerService" in names + assert "FolderPathsProxy" in names + + +def test_register_serializers(): + adapter = ComfyUIAdapter() + registry = SerializerRegistry.get_instance() + registry.clear() + + adapter.register_serializers(registry) + assert registry.has_handler("ModelPatcher") + assert registry.has_handler("CLIP") + assert registry.has_handler("VAE") + + registry.clear()