mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
fix for conditioning
This commit is contained in:
parent
57b306464e
commit
2cb06431e8
@ -671,7 +671,7 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
|
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
|
||||||
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
||||||
rope_phases = pos_embedder(coords)
|
rope_phases = pos_embedder(coords)
|
||||||
self.register_buffer("rope_phases", rope_phases)
|
self.register_buffer("rope_phases", rope_phases, persistent=False)
|
||||||
|
|
||||||
if pe_mode != "rope":
|
if pe_mode != "rope":
|
||||||
self.rope_phases = None
|
self.rope_phases = None
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from typing_extensions import override
|
|||||||
from comfy_api.latest import ComfyExtension, IO, Types
|
from comfy_api.latest import ComfyExtension, IO, Types
|
||||||
from comfy.ldm.trellis2.vae import SparseTensor
|
from comfy.ldm.trellis2.vae import SparseTensor
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import logging
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -250,28 +249,28 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
|
|||||||
|
|
||||||
return IO.NodeOutput(final_coords,)
|
return IO.NodeOutput(final_coords,)
|
||||||
|
|
||||||
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
||||||
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
||||||
|
|
||||||
def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
||||||
model_internal = model.model
|
model_internal = model.model
|
||||||
device = comfy.model_management.intermediate_device()
|
device = comfy.model_management.intermediate_device()
|
||||||
torch_device = comfy.model_management.get_torch_device()
|
torch_device = comfy.model_management.get_torch_device()
|
||||||
|
|
||||||
img_t = cropped_img_tensor.to(torch_device)
|
def prepare_tensor(pil_img, size):
|
||||||
|
resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS)
|
||||||
def prepare_tensor(img, size):
|
img_np = np.array(resized_pil).astype(np.float32) / 255.0
|
||||||
resized = torch.nn.functional.interpolate(img, size=(size, size), mode='bicubic', align_corners=False).clamp(0.0, 1.0)
|
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
|
||||||
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||||
|
|
||||||
model_internal.image_size = 512
|
model_internal.image_size = 512
|
||||||
input_512 = prepare_tensor(img_t, 512)
|
input_512 = prepare_tensor(cropped_img_tensor, 512)
|
||||||
cond_512 = model_internal(input_512)[0]
|
cond_512 = model_internal(input_512)[0]
|
||||||
|
|
||||||
cond_1024 = None
|
cond_1024 = None
|
||||||
if include_1024:
|
if include_1024:
|
||||||
model_internal.image_size = 1024
|
model_internal.image_size = 1024
|
||||||
input_1024 = prepare_tensor(img_t, 1024)
|
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
|
||||||
cond_1024 = model_internal(input_1024)[0]
|
cond_1024 = model_internal(input_1024)[0]
|
||||||
|
|
||||||
conditioning = {
|
conditioning = {
|
||||||
@ -341,14 +340,15 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
crop_x2 = int(center_x + size // 2)
|
crop_x2 = int(center_x + size // 2)
|
||||||
crop_y2 = int(center_y + size // 2)
|
crop_y2 = int(center_y + size // 2)
|
||||||
|
|
||||||
rgba_pil = Image.fromarray(rgba_np, 'RGBA')
|
rgba_pil = Image.fromarray(rgba_np)
|
||||||
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
|
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
|
||||||
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
|
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
|
||||||
else:
|
else:
|
||||||
|
import logging
|
||||||
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
|
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
|
||||||
cropped_np = rgba_np.astype(np.float32) / 255.0
|
cropped_np = rgba_np.astype(np.float32) / 255.0
|
||||||
|
|
||||||
bg_colors = {"black": [0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]}
|
bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]}
|
||||||
bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)
|
bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)
|
||||||
|
|
||||||
fg = cropped_np[:, :, :3]
|
fg = cropped_np[:, :, :3]
|
||||||
@ -358,10 +358,9 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
# to match trellis2 code (quantize -> dequantize)
|
# to match trellis2 code (quantize -> dequantize)
|
||||||
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
|
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
|
||||||
|
|
||||||
cropped_img_tensor = torch.from_numpy(composite_uint8).float() / 255.0
|
cropped_pil = Image.fromarray(composite_uint8)
|
||||||
cropped_img_tensor = cropped_img_tensor.movedim(-1, 0).unsqueeze(0)
|
|
||||||
|
|
||||||
conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True)
|
conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True)
|
||||||
|
|
||||||
embeds = conditioning["cond_1024"]
|
embeds = conditioning["cond_1024"]
|
||||||
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user