fix for conditioning

This commit is contained in:
Yousef Rafat 2026-04-07 23:02:33 +02:00
parent 57b306464e
commit 2cb06431e8
2 changed files with 15 additions and 16 deletions

View File

@ -671,7 +671,7 @@ class SparseStructureFlowModel(nn.Module):
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
rope_phases = pos_embedder(coords)
self.register_buffer("rope_phases", rope_phases)
self.register_buffer("rope_phases", rope_phases, persistent=False)
if pe_mode != "rope":
self.rope_phases = None

View File

@ -2,7 +2,6 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types
from comfy.ldm.trellis2.vae import SparseTensor
import comfy.model_management
import logging
from PIL import Image
import numpy as np
import torch
@ -250,28 +249,28 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
return IO.NodeOutput(final_coords,)
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
def run_conditioning(model, cropped_img_tensor, include_1024=True):
model_internal = model.model
device = comfy.model_management.intermediate_device()
torch_device = comfy.model_management.get_torch_device()
img_t = cropped_img_tensor.to(torch_device)
def prepare_tensor(img, size):
resized = torch.nn.functional.interpolate(img, size=(size, size), mode='bicubic', align_corners=False).clamp(0.0, 1.0)
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
def prepare_tensor(pil_img, size):
resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS)
img_np = np.array(resized_pil).astype(np.float32) / 255.0
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
model_internal.image_size = 512
input_512 = prepare_tensor(img_t, 512)
input_512 = prepare_tensor(cropped_img_tensor, 512)
cond_512 = model_internal(input_512)[0]
cond_1024 = None
if include_1024:
model_internal.image_size = 1024
input_1024 = prepare_tensor(img_t, 1024)
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
cond_1024 = model_internal(input_1024)[0]
conditioning = {
@ -341,14 +340,15 @@ class Trellis2Conditioning(IO.ComfyNode):
crop_x2 = int(center_x + size // 2)
crop_y2 = int(center_y + size // 2)
rgba_pil = Image.fromarray(rgba_np, 'RGBA')
rgba_pil = Image.fromarray(rgba_np)
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
else:
import logging
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
cropped_np = rgba_np.astype(np.float32) / 255.0
bg_colors = {"black": [0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]}
bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]}
bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)
fg = cropped_np[:, :, :3]
@ -358,10 +358,9 @@ class Trellis2Conditioning(IO.ComfyNode):
# to match trellis2 code (quantize -> dequantize)
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
cropped_img_tensor = torch.from_numpy(composite_uint8).float() / 255.0
cropped_img_tensor = cropped_img_tensor.movedim(-1, 0).unsqueeze(0)
cropped_pil = Image.fromarray(composite_uint8)
conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True)
conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True)
embeds = conditioning["cond_1024"]
positive = [[conditioning["cond_512"], {"embeds": embeds}]]