mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
progressing
This commit is contained in:
parent
30db88dd64
commit
2a27c3b417
@ -226,8 +226,11 @@ class DINOv3ViTLayer(nn.Module):
|
|||||||
class DINOv3ViTModel(nn.Module):
|
class DINOv3ViTModel(nn.Module):
|
||||||
def __init__(self, config, dtype, device, operations):
|
def __init__(self, config, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if dtype == torch.float16 and comfy.model_management.should_use_bf16(device, prioritize_performance=False):
|
use_bf16 = comfy.model_management.should_use_bf16(device, prioritize_performance=True)
|
||||||
|
if dtype == torch.float16 and use_bf16:
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
|
elif dtype == torch.float16 and not use_bf16:
|
||||||
|
dtype = torch.float32
|
||||||
num_hidden_layers = config["num_hidden_layers"]
|
num_hidden_layers = config["num_hidden_layers"]
|
||||||
hidden_size = config["hidden_size"]
|
hidden_size = config["hidden_size"]
|
||||||
num_attention_heads = config["num_attention_heads"]
|
num_attention_heads = config["num_attention_heads"]
|
||||||
|
|||||||
@ -10,8 +10,8 @@ import comfy.ops
|
|||||||
def var_attn_arg(kwargs):
|
def var_attn_arg(kwargs):
|
||||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||||
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
||||||
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) or kwargs.get("cu_seqlens_kv", cu_seqlens_q)
|
cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
|
||||||
max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) or kwargs.get("max_kv_seqlen", max_seqlen_q)
|
max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q)
|
||||||
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
|
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||||
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
|
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
|
||||||
|
|
||||||
@ -183,6 +183,7 @@ def calc_window_partition(
|
|||||||
|
|
||||||
|
|
||||||
def sparse_scaled_dot_product_attention(*args, **kwargs):
|
def sparse_scaled_dot_product_attention(*args, **kwargs):
|
||||||
|
q=None
|
||||||
arg_names_dict = {
|
arg_names_dict = {
|
||||||
1: ['qkv'],
|
1: ['qkv'],
|
||||||
2: ['q', 'kv'],
|
2: ['q', 'kv'],
|
||||||
@ -250,6 +251,12 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
|
|||||||
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
|
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
|
||||||
v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
|
v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
|
||||||
|
|
||||||
|
# TODO: change
|
||||||
|
if q is not None:
|
||||||
|
heads = q
|
||||||
|
else:
|
||||||
|
heads = qkv
|
||||||
|
heads = heads.shape[2]
|
||||||
if optimized_attention.__name__ == 'attention_xformers':
|
if optimized_attention.__name__ == 'attention_xformers':
|
||||||
if 'xops' not in globals():
|
if 'xops' not in globals():
|
||||||
import xformers.ops as xops
|
import xformers.ops as xops
|
||||||
@ -279,11 +286,15 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
|
|||||||
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
||||||
if num_all_args in [2, 3]:
|
if num_all_args in [2, 3]:
|
||||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||||
|
else:
|
||||||
|
cu_seqlens_kv = cu_seqlens_q
|
||||||
if num_all_args == 1:
|
if num_all_args == 1:
|
||||||
q, k, v = qkv.unbind(dim=1)
|
q, k, v = qkv.unbind(dim=1)
|
||||||
elif num_all_args == 2:
|
elif num_all_args == 2:
|
||||||
k, v = kv.unbind(dim=1)
|
k, v = kv.unbind(dim=1)
|
||||||
out = attention_pytorch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen),
|
||||||
|
skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
if s is not None:
|
if s is not None:
|
||||||
return s.replace(out)
|
return s.replace(out)
|
||||||
|
|||||||
@ -232,6 +232,8 @@ class SparseMultiHeadAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
q = self._linear(self.to_q, x)
|
q = self._linear(self.to_q, x)
|
||||||
q = self._reshape_chs(q, (self.num_heads, -1))
|
q = self._reshape_chs(q, (self.num_heads, -1))
|
||||||
|
dtype = next(self.to_kv.parameters()).dtype
|
||||||
|
context = context.to(dtype)
|
||||||
kv = self._linear(self.to_kv, context)
|
kv = self._linear(self.to_kv, context)
|
||||||
kv = self._fused_pre(kv, num_fused=2)
|
kv = self._fused_pre(kv, num_fused=2)
|
||||||
if self.qk_rms_norm:
|
if self.qk_rms_norm:
|
||||||
@ -760,15 +762,13 @@ class Trellis2(nn.Module):
|
|||||||
self.guidance_interval_txt = [0.6, 0.9]
|
self.guidance_interval_txt = [0.6, 0.9]
|
||||||
|
|
||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
# FIXME: should find a way to distinguish between 512/1024 models
|
|
||||||
# currently assumes 1024
|
|
||||||
transformer_options = kwargs.get("transformer_options", {})
|
transformer_options = kwargs.get("transformer_options", {})
|
||||||
embeds = kwargs.get("embeds")
|
embeds = kwargs.get("embeds")
|
||||||
if embeds is None:
|
if embeds is None:
|
||||||
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
||||||
#_, cond = context.chunk(2) # TODO
|
is_1024 = self.img2shape.resolution == 1024
|
||||||
cond = embeds.chunk(2)[0]
|
if is_1024:
|
||||||
context = torch.cat([torch.zeros_like(cond), cond])
|
context = embeds
|
||||||
coords = transformer_options.get("coords", None)
|
coords = transformer_options.get("coords", None)
|
||||||
mode = transformer_options.get("generation_mode", "structure_generation")
|
mode = transformer_options.get("generation_mode", "structure_generation")
|
||||||
if coords is not None:
|
if coords is not None:
|
||||||
|
|||||||
@ -2,6 +2,7 @@ from typing_extensions import override
|
|||||||
from comfy_api.latest import ComfyExtension, IO, Types
|
from comfy_api.latest import ComfyExtension, IO, Types
|
||||||
import torch.nn.functional as TF
|
import torch.nn.functional as TF
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy.utils import ProgressBar
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -250,7 +251,7 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
|
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
|
||||||
embeds = conditioning["cond_1024"] # should add that
|
embeds = conditioning["cond_1024"] # should add that
|
||||||
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
||||||
negative = [[conditioning["neg_cond"], {"embeds": embeds}]]
|
negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]]
|
||||||
return IO.NodeOutput(positive, negative)
|
return IO.NodeOutput(positive, negative)
|
||||||
|
|
||||||
class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||||
@ -512,15 +513,23 @@ class PostProcessMesh(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, mesh, simplify, fill_holes_perimeter):
|
def execute(cls, mesh, simplify, fill_holes_perimeter):
|
||||||
|
bar = ProgressBar(2)
|
||||||
mesh = copy.deepcopy(mesh)
|
mesh = copy.deepcopy(mesh)
|
||||||
verts, faces = mesh.vertices, mesh.faces
|
verts, faces = mesh.vertices, mesh.faces
|
||||||
|
|
||||||
if fill_holes_perimeter != 0.0:
|
if fill_holes_perimeter != 0.0:
|
||||||
verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter)
|
verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter)
|
||||||
|
bar.update(1)
|
||||||
|
else:
|
||||||
|
bar.update(1)
|
||||||
|
|
||||||
if simplify != 0:
|
if simplify != 0:
|
||||||
verts, faces = simplify_fn(verts, faces, simplify)
|
verts, faces = simplify_fn(verts, faces, simplify)
|
||||||
|
bar.update(1)
|
||||||
|
else:
|
||||||
|
bar.update(1)
|
||||||
|
|
||||||
|
# potentially adding laplacian smoothing
|
||||||
|
|
||||||
mesh.vertices = verts
|
mesh.vertices = verts
|
||||||
mesh.faces = faces
|
mesh.faces = faces
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user