small bug fixes

This commit is contained in:
Yousef Rafat 2026-02-09 00:41:01 +02:00
parent 955c00ee38
commit 704e1b5462
6 changed files with 21 additions and 18 deletions

View File

@ -2,7 +2,7 @@ import torch
import math
from comfy.ldm.modules.attention import optimized_attention
from typing import Tuple, Union, List
from vae import VarLenTensor
from comfy.ldm.trellis2.vae import VarLenTensor
FLASH_ATTN_3_AVA = True
try:

View File

@ -798,9 +798,12 @@ class Trellis2(nn.Module):
qk_rms_norm = True,
qk_rms_norm_cross = True,
init_txt_model=False, # for now
dtype=None, device=None, operations=None):
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
# for some reason it passes num_heads = -1
num_heads = 12
args = {
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,

View File

@ -6,7 +6,7 @@ from fractions import Fraction
import torch.nn.functional as F
from dataclasses import dataclass
import numpy as np
from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
@ -1457,10 +1457,9 @@ class SparseStructureDecoder(nn.Module):
return h
class Vae(nn.Module):
def __init__(self, config, operations=None):
def __init__(self, init_txt_model, operations=None):
super().__init__()
operations = operations or torch.nn
init_txt_model = config.get("init_txt_model", False)
if init_txt_model:
self.txt_dec = SparseUnetVaeDecoder(
out_channels=6,

View File

@ -109,15 +109,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
unet_config = {}
unet_config["image_model"] = "trellis2"
unet_config["init_txt_model"] = False
if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
unet_config["init_txt_model"] = True
else:
unet_config["init_txt_model"] = False
unet_config["resolution"] = 64
if metadata is not None:
if metadata["is_512"] is True:
if "is_512" in metadata and metadata["metadata"]:
unet_config["resolution"] = 32
else:
unet_config["resolution"] = 64
unet_config["num_heads"] = 12
return unet_config
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit

View File

@ -494,11 +494,10 @@ class VAE:
self.downscale_ratio = 32
self.latent_channels = 16
elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2
init_txt_model = False
if "txt_dec.blocks.1.16.norm1.weight" in sd:
config["init_txt_model"] = True
else:
config["init_txt_model"] = False
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(config)
init_txt_model = True
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}

View File

@ -198,7 +198,7 @@ class Trellis2Conditioning(IO.ComfyNode):
inputs=[
IO.ClipVision.Input("clip_vision_model"),
IO.Image.Input("image"),
IO.MultiCombo.Input("background_color", options=["black", "gray", "white"], default="black")
IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black")
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
@ -219,7 +219,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentTrellis2",
node_id="EmptyShapeLatentTrellis2",
category="latent/3d",
inputs=[
IO.Mesh.Input("structure_output"),
@ -242,7 +242,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentTrellis2",
node_id="EmptyTextureLatentTrellis2",
category="latent/3d",
inputs=[
IO.Mesh.Input("structure_output"),
@ -264,7 +264,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentTrellis2",
node_id="EmptyStructureLatentTrellis2",
category="latent/3d",
inputs=[
IO.Int.Input("resolution", default=3072, min=1, max=8192),