mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
small fixes
This commit is contained in:
parent
b5feac202c
commit
ee2b66a2f2
@ -828,6 +828,7 @@ class Trellis2(nn.Module):
|
|||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
embeds = kwargs.get("embeds")
|
embeds = kwargs.get("embeds")
|
||||||
mode = kwargs.get("generation_mode")
|
mode = kwargs.get("generation_mode")
|
||||||
|
coords = kwargs.get("coords")
|
||||||
transformer_options = kwargs.get("transformer_options")
|
transformer_options = kwargs.get("transformer_options")
|
||||||
sigmas = transformer_options.get("sigmas")[0].item()
|
sigmas = transformer_options.get("sigmas")[0].item()
|
||||||
if sigmas < 1.00001:
|
if sigmas < 1.00001:
|
||||||
@ -836,6 +837,9 @@ class Trellis2(nn.Module):
|
|||||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||||
|
|
||||||
|
if mode in ["shape_generation", "texture_generation"]:
|
||||||
|
x = SparseTensor(feats=x, coords=coords)
|
||||||
|
|
||||||
if mode == "shape_generation":
|
if mode == "shape_generation":
|
||||||
# TODO
|
# TODO
|
||||||
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
|
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
|
||||||
|
|||||||
@ -1,9 +1,7 @@
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO, Types
|
from comfy_api.latest import ComfyExtension, IO, Types
|
||||||
import torch
|
import torch
|
||||||
from comfy.ldm.trellis2.model import SparseTensor
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_patcher
|
|
||||||
|
|
||||||
shape_slat_normalization = {
|
shape_slat_normalization = {
|
||||||
"mean": torch.tensor([
|
"mean": torch.tensor([
|
||||||
@ -205,7 +203,6 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
samples = samples.to(load_device)
|
samples = samples.to(load_device)
|
||||||
decoded = decoder(samples)>0
|
decoded = decoder(samples)>0
|
||||||
decoder.to(offload_device)
|
decoder.to(offload_device)
|
||||||
comfy.model_management.get_offload_stream
|
|
||||||
out = Types.VOXEL(decoded.squeeze(1).float())
|
out = Types.VOXEL(decoded.squeeze(1).float())
|
||||||
return IO.NodeOutput(out)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
@ -253,10 +250,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_output):
|
def execute(cls, structure_output):
|
||||||
decoded = structure_output.data
|
decoded = structure_output.data
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1)
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
latent = torch.randn(coords.shape[0], in_channels)
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation", "coords": coords})
|
||||||
|
|
||||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -275,9 +272,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_output):
|
def execute(cls, structure_output):
|
||||||
# TODO
|
# TODO
|
||||||
|
decoded = structure_output.data
|
||||||
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1]))
|
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation", "coords": coords})
|
||||||
|
|
||||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user