mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-26 18:32:35 +08:00
coderabbit 2
This commit is contained in:
parent
f3d4125e49
commit
b3da8ed4c5
@ -207,11 +207,14 @@ class TorchHashMap:
|
|||||||
|
|
||||||
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
|
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
|
||||||
flat = flat_keys.long()
|
flat = flat_keys.long()
|
||||||
|
if self._n == 0:
|
||||||
|
return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
|
||||||
idx = torch.searchsorted(self.sorted_keys, flat)
|
idx = torch.searchsorted(self.sorted_keys, flat)
|
||||||
found = (idx < self._n) & (self.sorted_keys[idx] == flat)
|
idx_safe = torch.clamp(idx, max=self._n - 1)
|
||||||
|
found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
|
||||||
out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
|
out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
|
||||||
if found.any():
|
if found.any():
|
||||||
out[found] = self.sorted_vals[idx[found]]
|
out[found] = self.sorted_vals[idx_safe[found]]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,6 @@ from comfy.ldm.trellis2.attention import (
|
|||||||
)
|
)
|
||||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||||
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
||||||
import builtins
|
|
||||||
|
|
||||||
class SparseGELU(nn.GELU):
|
class SparseGELU(nn.GELU):
|
||||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||||
@ -829,19 +828,19 @@ class Trellis2(nn.Module):
|
|||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
# FIXME: should find a way to distinguish between 512/1024 models
|
# FIXME: should find a way to distinguish between 512/1024 models
|
||||||
# currently assumes 1024
|
# currently assumes 1024
|
||||||
|
transformer_options = kwargs.get("transformer_options")
|
||||||
embeds = kwargs.get("embeds")
|
embeds = kwargs.get("embeds")
|
||||||
_, cond = context.chunk(2)
|
#_, cond = context.chunk(2) # TODO
|
||||||
cond = embeds.chunk(2)[0]
|
cond = embeds.chunk(2)[0]
|
||||||
context = torch.cat([torch.zeros_like(cond), cond])
|
context = torch.cat([torch.zeros_like(cond), cond])
|
||||||
mode = getattr(builtins, "TRELLIS_MODE", "structure_generation")
|
coords = transformer_options.get("coords", None)
|
||||||
coords = getattr(builtins, "TRELLIS_COORDS", None)
|
mode = transformer_options.get("generation_mode", "structure_generation")
|
||||||
if coords is not None:
|
if coords is not None:
|
||||||
x = x.squeeze(0)
|
x = x.squeeze(0)
|
||||||
not_struct_mode = True
|
not_struct_mode = True
|
||||||
else:
|
else:
|
||||||
mode = "structure_generation"
|
mode = "structure_generation"
|
||||||
not_struct_mode = False
|
not_struct_mode = False
|
||||||
transformer_options = kwargs.get("transformer_options")
|
|
||||||
sigmas = transformer_options.get("sigmas")[0].item()
|
sigmas = transformer_options.get("sigmas")[0].item()
|
||||||
if sigmas < 1.00001:
|
if sigmas < 1.00001:
|
||||||
timestep *= 1000.0
|
timestep *= 1000.0
|
||||||
|
|||||||
@ -4,7 +4,6 @@ import torch
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import builtins
|
|
||||||
|
|
||||||
shape_slat_normalization = {
|
shape_slat_normalization = {
|
||||||
"mean": torch.tensor([
|
"mean": torch.tensor([
|
||||||
@ -258,21 +257,31 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Voxel.Input("structure_output"),
|
IO.Voxel.Input("structure_output"),
|
||||||
|
IO.Model.Input("model")
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
|
IO.Model.Output()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_output):
|
def execute(cls, structure_output, model):
|
||||||
decoded = structure_output.data.unsqueeze(1)
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = torch.randn(1, coords.shape[0], in_channels)
|
latent = torch.randn(1, coords.shape[0], in_channels)
|
||||||
builtins.TRELLIS_MODE = "shape_generation"
|
model = model.clone()
|
||||||
builtins.TRELLIS_COORDS = coords
|
if "transformer_options" not in model.model_options:
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
model.model_options = {}
|
||||||
|
else:
|
||||||
|
model.model_options = model.model_options.copy()
|
||||||
|
|
||||||
|
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||||
|
|
||||||
|
model.model_options["transformer_options"]["coords"] = coords
|
||||||
|
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||||
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
|
||||||
|
|
||||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -285,19 +294,29 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
|
IO.Model.Output()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_output):
|
def execute(cls, structure_output, model):
|
||||||
# TODO
|
# TODO
|
||||||
decoded = structure_output.data.unsqueeze(1)
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
||||||
builtins.TRELLIS_MODE = "texture_generation"
|
model = model.clone()
|
||||||
builtins.TRELLIS_COORDS = coords
|
if "transformer_options" not in model.model_options:
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
model.model_options = {}
|
||||||
|
else:
|
||||||
|
model.model_options = model.model_options.copy()
|
||||||
|
|
||||||
|
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||||
|
|
||||||
|
model.model_options["transformer_options"]["coords"] = coords
|
||||||
|
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||||
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
|
||||||
|
|
||||||
|
|
||||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user