diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 681666430..d95b071b5 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -53,57 +53,37 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha ) return out - -# TODO repalce with optimized attention def scaled_dot_product_attention(*args, **kwargs): num_all_args = len(args) + len(kwargs) q = None if num_all_args == 1: - qkv = args[0] if len(args) > 0 else kwargs['qkv'] - + qkv = args[0] if len(args) > 0 else kwargs.get('qkv') elif num_all_args == 2: - q = args[0] if len(args) > 0 else kwargs['q'] - kv = args[1] if len(args) > 1 else kwargs['kv'] - + q = args[0] if len(args) > 0 else kwargs.get('q') + kv = args[1] if len(args) > 1 else kwargs.get('kv') elif num_all_args == 3: - q = args[0] if len(args) > 0 else kwargs['q'] - k = args[1] if len(args) > 1 else kwargs['k'] - v = args[2] if len(args) > 2 else kwargs['v'] + q = args[0] if len(args) > 0 else kwargs.get('q') + k = args[1] if len(args) > 1 else kwargs.get('k') + v = args[2] if len(args) > 2 else kwargs.get('v') if q is not None: - heads = q + heads = q.shape[2] else: - heads = qkv - heads = heads.shape[2] + heads = qkv.shape[3] - if optimized_attention.__name__ == 'attention_xformers': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - #out = xops.memory_efficient_attention(q, k, v) - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - elif optimized_attention.__name__ == 'attention_flash': - if num_all_args == 2: - k, v = kv.unbind(dim=2) - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - elif optimized_attention.__name__ == 'attention_pytorch': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - q = q.permute(0, 2, 1, 3) # [N, H, L, C] - k = k.permute(0, 2, 1, 3) # [N, H, L, 
C] - v = v.permute(0, 2, 1, 3) # [N, H, L, C] - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - out = out.permute(0, 2, 1, 3) # [N, L, H, C] - elif optimized_attention.__name__ == 'attention_basic': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **{key: val for key, val in kwargs.items() if key not in ('q', 'k', 'v', 'qkv', 'kv')}) + + out = out.permute(0, 2, 1, 3) return out diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 40646f369..7c6ffdd69 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -788,7 +788,10 @@ class Trellis2(nn.Module): sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 - cond = context.chunk(2)[1] + if context.size(0) > 1: + cond = context.chunk(2)[1] + else: + cond = context shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] @@ -836,7 +839,7 @@ class Trellis2(nn.Module): if slat is None: raise ValueError("shape_slat can't be None") - base_slat_feats = slat.feats[:N] + base_slat_feats = slat[:N] slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 77e6a3add..088cdd3f1 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -72,7 +72,13 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, 
resolution): nearest_idx = torch.from_numpy(nearest_idx_np).long() v_colors = voxel_colors[nearest_idx] - final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0) + # to [0, 1] + srgb_colors = (v_colors * 0.5 + 0.5).clamp(0, 1) + + # to Linear RGB (required for GLTF) + linear_colors = torch.pow(srgb_colors, 2.2) + + final_colors = linear_colors.unsqueeze(0) out_mesh = copy.deepcopy(mesh) out_mesh.colors = final_colors