mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
progressing
This commit is contained in:
parent
30db88dd64
commit
2a27c3b417
@ -226,8 +226,11 @@ class DINOv3ViTLayer(nn.Module):
|
||||
class DINOv3ViTModel(nn.Module):
|
||||
def __init__(self, config, dtype, device, operations):
|
||||
super().__init__()
|
||||
if dtype == torch.float16 and comfy.model_management.should_use_bf16(device, prioritize_performance=False):
|
||||
use_bf16 = comfy.model_management.should_use_bf16(device, prioritize_performance=True)
|
||||
if dtype == torch.float16 and use_bf16:
|
||||
dtype = torch.bfloat16
|
||||
elif dtype == torch.float16 and not use_bf16:
|
||||
dtype = torch.float32
|
||||
num_hidden_layers = config["num_hidden_layers"]
|
||||
hidden_size = config["hidden_size"]
|
||||
num_attention_heads = config["num_attention_heads"]
|
||||
|
||||
@ -10,8 +10,8 @@ import comfy.ops
|
||||
def var_attn_arg(kwargs):
|
||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) or kwargs.get("cu_seqlens_kv", cu_seqlens_q)
|
||||
max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) or kwargs.get("max_kv_seqlen", max_seqlen_q)
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
|
||||
max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q)
|
||||
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
|
||||
|
||||
@ -183,6 +183,7 @@ def calc_window_partition(
|
||||
|
||||
|
||||
def sparse_scaled_dot_product_attention(*args, **kwargs):
|
||||
q=None
|
||||
arg_names_dict = {
|
||||
1: ['qkv'],
|
||||
2: ['q', 'kv'],
|
||||
@ -250,6 +251,12 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
|
||||
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
|
||||
v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
|
||||
|
||||
# TODO: change
|
||||
if q is not None:
|
||||
heads = q
|
||||
else:
|
||||
heads = qkv
|
||||
heads = heads.shape[2]
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
if 'xops' not in globals():
|
||||
import xformers.ops as xops
|
||||
@ -279,11 +286,15 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
|
||||
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args in [2, 3]:
|
||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||
else:
|
||||
cu_seqlens_kv = cu_seqlens_q
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=1)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=1)
|
||||
out = attention_pytorch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||
out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen),
|
||||
skip_reshape=True, skip_output_reshape=True)
|
||||
|
||||
if s is not None:
|
||||
return s.replace(out)
|
||||
|
||||
@ -232,6 +232,8 @@ class SparseMultiHeadAttention(nn.Module):
|
||||
else:
|
||||
q = self._linear(self.to_q, x)
|
||||
q = self._reshape_chs(q, (self.num_heads, -1))
|
||||
dtype = next(self.to_kv.parameters()).dtype
|
||||
context = context.to(dtype)
|
||||
kv = self._linear(self.to_kv, context)
|
||||
kv = self._fused_pre(kv, num_fused=2)
|
||||
if self.qk_rms_norm:
|
||||
@ -760,15 +762,13 @@ class Trellis2(nn.Module):
|
||||
self.guidance_interval_txt = [0.6, 0.9]
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
# FIXME: should find a way to distinguish between 512/1024 models
|
||||
# currently assumes 1024
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
embeds = kwargs.get("embeds")
|
||||
if embeds is None:
|
||||
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
||||
#_, cond = context.chunk(2) # TODO
|
||||
cond = embeds.chunk(2)[0]
|
||||
context = torch.cat([torch.zeros_like(cond), cond])
|
||||
is_1024 = self.img2shape.resolution == 1024
|
||||
if is_1024:
|
||||
context = embeds
|
||||
coords = transformer_options.get("coords", None)
|
||||
mode = transformer_options.get("generation_mode", "structure_generation")
|
||||
if coords is not None:
|
||||
|
||||
@ -2,6 +2,7 @@ from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
import torch.nn.functional as TF
|
||||
import comfy.model_management
|
||||
from comfy.utils import ProgressBar
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -250,7 +251,7 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
|
||||
embeds = conditioning["cond_1024"] # should add that
|
||||
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
||||
negative = [[conditioning["neg_cond"], {"embeds": embeds}]]
|
||||
negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]]
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
@ -512,15 +513,23 @@ class PostProcessMesh(IO.ComfyNode):
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, mesh, simplify, fill_holes_perimeter):
|
||||
bar = ProgressBar(2)
|
||||
mesh = copy.deepcopy(mesh)
|
||||
verts, faces = mesh.vertices, mesh.faces
|
||||
|
||||
if fill_holes_perimeter != 0.0:
|
||||
verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter)
|
||||
bar.update(1)
|
||||
else:
|
||||
bar.update(1)
|
||||
|
||||
if simplify != 0:
|
||||
verts, faces = simplify_fn(verts, faces, simplify)
|
||||
bar.update(1)
|
||||
else:
|
||||
bar.update(1)
|
||||
|
||||
# potentially adding laplacian smoothing
|
||||
|
||||
mesh.vertices = verts
|
||||
mesh.faces = faces
|
||||
|
||||
Loading…
Reference in New Issue
Block a user