mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
small bug fixes
This commit is contained in:
parent
955c00ee38
commit
704e1b5462
@ -2,7 +2,7 @@ import torch
|
||||
import math
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from typing import Tuple, Union, List
|
||||
from vae import VarLenTensor
|
||||
from comfy.ldm.trellis2.vae import VarLenTensor
|
||||
|
||||
FLASH_ATTN_3_AVA = True
|
||||
try:
|
||||
|
||||
@ -798,9 +798,12 @@ class Trellis2(nn.Module):
|
||||
qk_rms_norm = True,
|
||||
qk_rms_norm_cross = True,
|
||||
init_txt_model=False, # for now
|
||||
dtype=None, device=None, operations=None):
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
# for some reason it passes num_heads = -1
|
||||
num_heads = 12
|
||||
args = {
|
||||
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
|
||||
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
||||
|
||||
@ -6,7 +6,7 @@ from fractions import Fraction
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
||||
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
||||
|
||||
|
||||
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
|
||||
@ -1457,10 +1457,9 @@ class SparseStructureDecoder(nn.Module):
|
||||
return h
|
||||
|
||||
class Vae(nn.Module):
|
||||
def __init__(self, config, operations=None):
|
||||
def __init__(self, init_txt_model, operations=None):
|
||||
super().__init__()
|
||||
operations = operations or torch.nn
|
||||
init_txt_model = config.get("init_txt_model", False)
|
||||
if init_txt_model:
|
||||
self.txt_dec = SparseUnetVaeDecoder(
|
||||
out_channels=6,
|
||||
|
||||
@ -109,15 +109,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "trellis2"
|
||||
|
||||
unet_config["init_txt_model"] = False
|
||||
if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config["init_txt_model"] = True
|
||||
else:
|
||||
unet_config["init_txt_model"] = False
|
||||
|
||||
unet_config["resolution"] = 64
|
||||
if metadata is not None:
|
||||
if metadata["is_512"] is True:
|
||||
if "is_512" in metadata and metadata["metadata"]:
|
||||
unet_config["resolution"] = 32
|
||||
else:
|
||||
unet_config["resolution"] = 64
|
||||
|
||||
unet_config["num_heads"] = 12
|
||||
return unet_config
|
||||
|
||||
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
||||
|
||||
@ -494,11 +494,10 @@ class VAE:
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2
|
||||
init_txt_model = False
|
||||
if "txt_dec.blocks.1.16.norm1.weight" in sd:
|
||||
config["init_txt_model"] = True
|
||||
else:
|
||||
config["init_txt_model"] = False
|
||||
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(config)
|
||||
init_txt_model = True
|
||||
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
|
||||
@ -198,7 +198,7 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.ClipVision.Input("clip_vision_model"),
|
||||
IO.Image.Input("image"),
|
||||
IO.MultiCombo.Input("background_color", options=["black", "gray", "white"], default="black")
|
||||
IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black")
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
@ -219,7 +219,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentTrellis2",
|
||||
node_id="EmptyShapeLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Mesh.Input("structure_output"),
|
||||
@ -242,7 +242,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentTrellis2",
|
||||
node_id="EmptyTextureLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Mesh.Input("structure_output"),
|
||||
@ -264,7 +264,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentTrellis2",
|
||||
node_id="EmptyStructureLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Int.Input("resolution", default=3072, min=1, max=8192),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user