mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
texture generation works
This commit is contained in:
parent
72640888ff
commit
57b306464e
@ -53,57 +53,37 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
# TODO repalce with optimized attention
|
||||
def scaled_dot_product_attention(*args, **kwargs):
|
||||
num_all_args = len(args) + len(kwargs)
|
||||
|
||||
q = None
|
||||
if num_all_args == 1:
|
||||
qkv = args[0] if len(args) > 0 else kwargs['qkv']
|
||||
|
||||
qkv = args[0] if len(args) > 0 else kwargs.get('qkv')
|
||||
elif num_all_args == 2:
|
||||
q = args[0] if len(args) > 0 else kwargs['q']
|
||||
kv = args[1] if len(args) > 1 else kwargs['kv']
|
||||
|
||||
q = args[0] if len(args) > 0 else kwargs.get('q')
|
||||
kv = args[1] if len(args) > 1 else kwargs.get('kv')
|
||||
elif num_all_args == 3:
|
||||
q = args[0] if len(args) > 0 else kwargs['q']
|
||||
k = args[1] if len(args) > 1 else kwargs['k']
|
||||
v = args[2] if len(args) > 2 else kwargs['v']
|
||||
q = args[0] if len(args) > 0 else kwargs.get('q')
|
||||
k = args[1] if len(args) > 1 else kwargs.get('k')
|
||||
v = args[2] if len(args) > 2 else kwargs.get('v')
|
||||
|
||||
if q is not None:
|
||||
heads = q
|
||||
heads = q.shape[2]
|
||||
else:
|
||||
heads = qkv
|
||||
heads = heads.shape[2]
|
||||
heads = qkv.shape[3]
|
||||
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
#out = xops.memory_efficient_attention(q, k, v)
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
if num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
elif optimized_attention.__name__ == 'attention_pytorch':
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
|
||||
elif optimized_attention.__name__ == 'attention_basic':
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs)
|
||||
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@ -788,7 +788,10 @@ class Trellis2(nn.Module):
|
||||
sigmas = transformer_options.get("sigmas")[0].item()
|
||||
if sigmas < 1.00001:
|
||||
timestep *= 1000.0
|
||||
cond = context.chunk(2)[1]
|
||||
if context.size(0) > 1:
|
||||
cond = context.chunk(2)[1]
|
||||
else:
|
||||
cond = context
|
||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||
|
||||
@ -836,7 +839,7 @@ class Trellis2(nn.Module):
|
||||
if slat is None:
|
||||
raise ValueError("shape_slat can't be None")
|
||||
|
||||
base_slat_feats = slat.feats[:N]
|
||||
base_slat_feats = slat[:N]
|
||||
slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device)
|
||||
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1))
|
||||
out = self.shape2txt(x_st, t_eval, c_eval)
|
||||
|
||||
@ -72,7 +72,13 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
|
||||
nearest_idx = torch.from_numpy(nearest_idx_np).long()
|
||||
v_colors = voxel_colors[nearest_idx]
|
||||
|
||||
final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0)
|
||||
# to [0, 1]
|
||||
srgb_colors = (v_colors * 0.5 + 0.5).clamp(0, 1)
|
||||
|
||||
# to Linear RGB (required for GLTF)
|
||||
linear_colors = torch.pow(srgb_colors, 2.2)
|
||||
|
||||
final_colors = linear_colors.unsqueeze(0)
|
||||
|
||||
out_mesh = copy.deepcopy(mesh)
|
||||
out_mesh.colors = final_colors
|
||||
|
||||
Loading…
Reference in New Issue
Block a user