diff --git a/.coderabbit.yaml b/.coderabbit.yaml index 3849d4b26..0d1e49270 100644 --- a/.coderabbit.yaml +++ b/.coderabbit.yaml @@ -1,6 +1,7 @@ # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json language: "en-US" early_access: false +tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code." reviews: profile: "chill" @@ -35,6 +36,14 @@ reviews: - "!**/*.bat" path_instructions: + - path: "**" + instructions: | + IMPORTANT: Only comment on issues directly introduced by this PR's code changes. + Do NOT flag pre-existing issues in code that was merely moved, re-indented, + de-indented, or reformatted without logic changes. If code appears in the diff + only due to whitespace or structural reformatting (e.g., removing a `with:` block), + treat it as unchanged. Contributors should not feel obligated to address + pre-existing issues outside the scope of their contribution. - path: "comfy/**" instructions: | Core ML/diffusion engine. Focus on: @@ -74,7 +83,11 @@ reviews: auto_review: enabled: true auto_incremental_review: true - drafts: true + drafts: false + ignore_title_keywords: + - "WIP" + - "DO NOT REVIEW" + - "DO NOT MERGE" finishing_touches: docstrings: @@ -84,7 +97,7 @@ reviews: tools: ruff: - enabled: true + enabled: false pylint: enabled: false flake8: diff --git a/app/subgraph_manager.py b/app/subgraph_manager.py index 6a8f586a4..08ad8c302 100644 --- a/app/subgraph_manager.py +++ b/app/subgraph_manager.py @@ -53,7 +53,7 @@ class SubgraphManager: return entry_id, entry async def load_entry_data(self, entry: SubgraphEntry): - with open(entry['path'], 'r') as f: + with open(entry['path'], 'r', encoding='utf-8') as f: entry['data'] = f.read() return entry diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 2c6954ecd..2b080aaeb 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import ( LTXVModel, ) from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import comfy.ldm.common_dit class CompressedTimestep: @@ -450,6 +451,29 @@ class LTXAVModel(LTXVModel): operations=self.operations, ) + self.audio_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=self.operations, + ) + + self.video_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=self.operations, + ) + + def preprocess_text_embeds(self, context): + if context.shape[-1] == self.caption_channels * 2: + return context + out_vid = self.video_embeddings_connector(context)[0] + out_audio = self.audio_embeddings_connector(context)[0] + return torch.concat((out_vid, out_audio), dim=-1) + def _init_transformer_blocks(self, device, dtype, **kwargs): """Initialize transformer blocks for LTXAV.""" self.transformer_blocks = nn.ModuleList( diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index 06f5ada89..33adb9671 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -157,11 +157,9 @@ class Embeddings1DConnector(nn.Module): self.num_learnable_registers = num_learnable_registers if self.num_learnable_registers: self.learnable_registers = nn.Parameter( - torch.rand( + torch.empty( self.num_learnable_registers, inner_dim, dtype=dtype, device=device ) - * 2.0 - - 1.0 ) def get_fractional_positions(self, indices_grid): @@ -234,7 +232,7 @@ class Embeddings1DConnector(nn.Module): return indices - def precompute_freqs_cis(self, indices_grid, spacing="exp"): + def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None): dim = self.inner_dim n_elem = 2 # 2 because of cos and sin freqs = self.precompute_freqs(indices_grid, spacing) @@ -247,7 +245,7 @@ class Embeddings1DConnector(nn.Module): ) else: cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) - return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope + return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope def forward( self, @@ -288,7 +286,7 @@ class Embeddings1DConnector(nn.Module): hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device ) indices_grid = indices_grid[None, None, :] - freqs_cis = self.precompute_freqs_cis(indices_grid) + freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype) # 2. Blocks for block_idx, block in enumerate(self.transformer_1d_blocks): diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 858bd4cc7..0b7da2852 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered): return dest_views -aimdo_allocator = None +aimdo_enabled = False diff --git a/comfy/model_base.py b/comfy/model_base.py index 9dcef8741..2f49578f6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -988,10 +988,14 @@ class LTXAV(BaseModel): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) attention_mask = kwargs.get("attention_mask", None) + device = kwargs["device"] + if attention_mask is not None: out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: + if hasattr(self.diffusion_model, "preprocess_text_embeds"): + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference())) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) diff --git a/comfy/model_management.py b/comfy/model_management.py index 38c3e482b..1fe56a62b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None: + if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled: return torch_dev else: return cpu_dev @@ -1121,7 +1121,6 @@ def get_cast_buffer(offload_stream, device, size, ref): synchronize() del STREAM_CAST_BUFFERS[offload_stream] del cast_buffer - #FIXME: This doesn't work in Aimdo because mempool cant clear cache soft_empty_cache() with wf_context: cast_buffer = torch.empty((size), dtype=torch.int8, device=device) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 82fbacf59..e2ce22e37 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -3,7 +3,6 @@ import os from transformers import T5TokenizerFast from .spiece_tokenizer import SPieceTokenizer import comfy.text_encoders.genmo -from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import torch import comfy.utils import math @@ -109,22 +108,6 @@ class LTXAVTEModel(torch.nn.Module): operations = self.gemma3_12b.operations # TODO self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) - self.audio_embeddings_connector = Embeddings1DConnector( - split_rope=True, - double_precision_rope=True, - dtype=dtype, - device=device, - operations=operations, - ) - - self.video_embeddings_connector = Embeddings1DConnector( - split_rope=True, - double_precision_rope=True, - dtype=dtype, - device=device, - operations=operations, - ) - def set_clip_options(self, options): self.execution_device = options.get("execution_device", self.execution_device) self.gemma3_12b.set_clip_options(options) @@ -146,10 +129,6 @@ class LTXAVTEModel(torch.nn.Module): out = out.reshape((out.shape[0], out.shape[1], -1)) out = self.text_embedding_projection(out) out = out.float() - out_vid = self.video_embeddings_connector(out)[0] - out_audio = self.audio_embeddings_connector(out)[0] - out = torch.concat((out_vid, out_audio), dim=-1) - return out.to(out_device), pooled def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): @@ -159,14 +138,14 @@ class LTXAVTEModel(torch.nn.Module): if "model.layers.47.self_attn.q_norm.weight" in sd: return self.gemma3_12b.load_sd(sd) else: - sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True) if len(sdo) == 0: sdo = sd missing_all = [] unexpected_all = [] - for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: + for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]: component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} if component_sd: missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False)) diff --git a/comfy/utils.py b/comfy/utils.py index 518757d98..1558d0a4e 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1154,7 +1154,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) def model_trange(*args, **kwargs): - if comfy.memory_management.aimdo_allocator is None: + if not comfy.memory_management.aimdo_enabled: return trange(*args, **kwargs) pbar = trange(*args, **kwargs, smoothing=1.0) diff --git a/comfy_extras/nodes_glsl.py b/comfy_extras/nodes_glsl.py index 18a35d846..75ffb6d80 100644 --- a/comfy_extras/nodes_glsl.py +++ b/comfy_extras/nodes_glsl.py @@ -716,12 +716,12 @@ def _render_shader_batch( gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) gl.glUseProgram(0) - if input_textures: - gl.glDeleteTextures(len(input_textures), input_textures) - if output_textures: - gl.glDeleteTextures(len(output_textures), output_textures) - if ping_pong_textures: - gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures) + for tex in input_textures: + gl.glDeleteTextures(tex) + for tex in output_textures: + gl.glDeleteTextures(tex) + for tex in ping_pong_textures: + gl.glDeleteTextures(tex) if fbo is not None: gl.glDeleteFramebuffers(1, [fbo]) for pp_fbo in ping_pong_fbos: diff --git a/comfy_extras/nodes_nag.py b/comfy_extras/nodes_nag.py index 033e40eb9..b57181848 100644 --- a/comfy_extras/nodes_nag.py +++ b/comfy_extras/nodes_nag.py @@ -10,7 +10,7 @@ class NAGuidance(io.ComfyNode): node_id="NAGuidance", display_name="Normalized Attention Guidance", description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.", - category="", + category="advanced/guidance", is_experimental=True, inputs=[ io.Model.Input("model", tooltip="The model to apply NAG to."), diff --git a/cuda_malloc.py b/cuda_malloc.py index b2182df37..f7651981c 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,10 +1,8 @@ import os import importlib.util -from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram +from comfy.cli_args import args, PerformanceFeature import subprocess -import comfy_aimdo.control - #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. def get_gpu_names(): if os.name == 'nt': @@ -87,10 +85,6 @@ if not args.cuda_malloc: except: pass -if enables_dynamic_vram() and comfy_aimdo.control.init(): - args.cuda_malloc = False - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "" - if args.disable_cuda_malloc: args.cuda_malloc = False diff --git a/execution.py b/execution.py index f549a2f0f..75b021892 100644 --- a/execution.py +++ b/execution.py @@ -9,7 +9,6 @@ import traceback from enum import Enum from typing import List, Literal, NamedTuple, Optional, Union import asyncio -from contextlib import nullcontext import torch @@ -521,19 +520,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - #Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows - #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc - #that we just want to cull out each model run. - allocator = comfy.memory_management.aimdo_allocator - with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): - try: - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) - finally: - if allocator is not None: - if args.verbose == "DEBUG": - comfy_aimdo.model_vbar.vbars_analyze() - comfy.model_management.reset_cast_buffers() - comfy_aimdo.model_vbar.vbars_reset_watermark_limits() + try: + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + finally: + if comfy.memory_management.aimdo_enabled: + if args.verbose == "DEBUG": + comfy_aimdo.control.analyze() + comfy.model_management.reset_cast_buffers() + comfy_aimdo.model_vbar.vbars_reset_watermark_limits() if has_pending_tasks: pending_async_nodes[unique_id] = output_data diff --git a/main.py b/main.py index 92d705b4d..39e605deb 100644 --- a/main.py +++ b/main.py @@ -173,6 +173,10 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") +import comfy_aimdo.control + +if enables_dynamic_vram(): + comfy_aimdo.control.init() import comfy.utils @@ -188,13 +192,9 @@ import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher -import comfy_aimdo.control -import comfy_aimdo.torch - if enables_dynamic_vram(): if comfy.model_management.torch_version_numeric < (2, 8): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - comfy.memory_management.aimdo_allocator = None elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): if args.verbose == 'DEBUG': comfy_aimdo.control.set_log_debug() @@ -208,11 +208,10 @@ if enables_dynamic_vram(): comfy_aimdo.control.set_log_info() comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic - comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator() + comfy.memory_management.aimdo_enabled = True logging.info("DynamicVRAM support detected and enabled") else: logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - comfy.memory_management.aimdo_allocator = None def cuda_malloc_warning(): diff --git a/requirements.txt b/requirements.txt index 3a9bfde46..8fbb0dbd6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.1.8 +comfy-aimdo>=0.2.0 requests #non essential dependencies: