Compare commits

..

1 Commits

Author SHA1 Message Date
Yousef R. Gamaleldin
eedecee439
Merge f1d25a460c into dd86b15521 2026-02-02 18:14:01 +02:00
5 changed files with 35 additions and 89 deletions

View File

@ -1,9 +1,9 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
from vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
from typing import Optional, Tuple, Literal, Union, List
from comfy.ldm.trellis2.attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention
from attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
class SparseGELU(nn.GELU):

View File

@ -19,8 +19,7 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import threading
from comfy.cli_args import args, PerformanceFeature
import torch
import sys
import platform
@ -651,7 +650,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@ -747,25 +746,8 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
current_loaded_models.insert(0, loaded_model)
return
def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
with torch.inference_mode():
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
soft_empty_cache()
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
#Deliberately load models outside of the Aimdo mempool so they can be retained accross
#nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
#thread local. So exploit that to escape context
if enables_dynamic_vram():
t = threading.Thread(
target=load_models_gpu_thread,
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
)
t.start()
t.join()
else:
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
def load_model_gpu(model):
return load_models_gpu([model])
def loaded_models(only_currently_used=False):
output = []
@ -1130,11 +1112,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
return None
if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
#I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
synchronize()
torch.cuda.synchronize()
del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
soft_empty_cache()
torch.cuda.empty_cache()
with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
@ -1150,7 +1132,9 @@ def reset_cast_buffers():
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
STREAM_CAST_BUFFERS.clear()
soft_empty_cache()
if comfy.memory_management.aimdo_allocator is None:
#Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
torch.cuda.empty_cache()
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
@ -1300,7 +1284,7 @@ def discard_cuda_async_error():
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
synchronize()
torch.cuda.synchronize()
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass
@ -1704,12 +1688,6 @@ def lora_compute_dtype(device):
LORA_COMPUTE_DTYPES[device] = dtype
return dtype
def synchronize():
if is_intel_xpu():
torch.xpu.synchronize()
elif torch.cuda.is_available():
torch.cuda.synchronize()
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
@ -1735,6 +1713,9 @@ def debug_memory_summary():
return torch.cuda.memory.memory_summary()
return ""
#TODO: might be cleaner to put this somewhere else
import threading
class InterruptProcessingException(Exception):
pass

View File

@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher):
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None, 1e32)
self.partially_unload(None)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above

View File

@ -193,50 +193,7 @@ class Trellis2Conditioning(IO.ComfyNode):
negative = [[conditioning["cond_neg"], {embeds}]]
return IO.NodeOutput(positive, negative)
class EmptyShapeLatentTrellis2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentTrellis2",
category="latent/3d",
inputs=[
IO.Latent.Input("structure_output"),
],
outputs=[
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, structure_output):
# i will see what i have to do here
coords = structure_output or structure_output.coords
in_channels = 32
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class EmptyTextureLatentTrellis2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentTrellis2",
category="latent/3d",
inputs=[
IO.Latent.Input("structure_output"),
],
outputs=[
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, structure_output):
# TODO
in_channels = 32
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class EmptyStructureLatentTrellis2(IO.ComfyNode):
class EmptyLatentTrellis2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
@ -245,26 +202,35 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
inputs=[
IO.Int.Input("resolution", default=3072, min=1, max=8192),
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
IO.Vae.Input("vae"),
IO.Boolean.Input("shape_generation", tooltip="Setting to false will generate texture."),
IO.MultiCombo.Input("generation_type", options=["structure_generation", "shape_generation", "texture_generation"])
],
outputs=[
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, res, batch_size):
in_channels = 32
latent = torch.randn(batch_size, in_channels, res, res, res)
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
@classmethod
def execute(cls, batch_size, coords, vae, generation_type) -> IO.NodeOutput:
# TODO: i will probably update how shape/texture is generated
# could split this too
in_channels = 32
shape_generation = generation_type == "shape_generation"
device = comfy.model_management.intermediate_device()
if shape_generation:
latent = SparseTensor(feats=torch.randn(batch_size, in_channels).to(device), coords=coords)
else:
# coords = shape_slat in txt gen case
latent = coords.replace(feats=torch.randn(coords.coords.shape[0], in_channels - coords.feats.shape[1]).to(device))
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class Trellis2Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
Trellis2Conditioning,
EmptyShapeLatentTrellis2,
EmptyStructureLatentTrellis2,
EmptyTextureLatentTrellis2,
EmptyLatentTrellis2,
VaeDecodeTextureTrellis,
VaeDecodeShapeTrellis
]

View File

@ -2433,8 +2433,7 @@ async def init_builtin_extra_nodes():
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_lora_debug.py",
"nodes_color.py",
"nodes_trellis2.py"
"nodes_color.py"
]
import_failed = []