texture generation works

This commit is contained in:
Yousef Rafat 2026-04-03 01:22:38 +02:00
parent 72640888ff
commit 57b306464e
3 changed files with 32 additions and 43 deletions

View File

@@ -53,57 +53,37 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
)
return out
# TODO replace with optimized attention
def scaled_dot_product_attention(*args, **kwargs):
num_all_args = len(args) + len(kwargs)
q = None
if num_all_args == 1:
qkv = args[0] if len(args) > 0 else kwargs['qkv']
qkv = args[0] if len(args) > 0 else kwargs.get('qkv')
elif num_all_args == 2:
q = args[0] if len(args) > 0 else kwargs['q']
kv = args[1] if len(args) > 1 else kwargs['kv']
q = args[0] if len(args) > 0 else kwargs.get('q')
kv = args[1] if len(args) > 1 else kwargs.get('kv')
elif num_all_args == 3:
q = args[0] if len(args) > 0 else kwargs['q']
k = args[1] if len(args) > 1 else kwargs['k']
v = args[2] if len(args) > 2 else kwargs['v']
q = args[0] if len(args) > 0 else kwargs.get('q')
k = args[1] if len(args) > 1 else kwargs.get('k')
v = args[2] if len(args) > 2 else kwargs.get('v')
if q is not None:
heads = q
heads = q.shape[2]
else:
heads = qkv
heads = heads.shape[2]
heads = qkv.shape[3]
if optimized_attention.__name__ == 'attention_xformers':
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
#out = xops.memory_efficient_attention(q, k, v)
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
elif optimized_attention.__name__ == 'attention_flash':
if num_all_args == 2:
k, v = kv.unbind(dim=2)
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
elif optimized_attention.__name__ == 'attention_pytorch':
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
elif optimized_attention.__name__ == 'attention_basic':
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs)
out = out.permute(0, 2, 1, 3)
return out

View File

@@ -788,7 +788,10 @@ class Trellis2(nn.Module):
sigmas = transformer_options.get("sigmas")[0].item()
if sigmas < 1.00001:
timestep *= 1000.0
cond = context.chunk(2)[1]
if context.size(0) > 1:
cond = context.chunk(2)[1]
else:
cond = context
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
@@ -836,7 +839,7 @@ class Trellis2(nn.Module):
if slat is None:
raise ValueError("shape_slat can't be None")
base_slat_feats = slat.feats[:N]
base_slat_feats = slat[:N]
slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device)
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1))
out = self.shape2txt(x_st, t_eval, c_eval)

View File

@@ -72,7 +72,13 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
nearest_idx = torch.from_numpy(nearest_idx_np).long()
v_colors = voxel_colors[nearest_idx]
final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0)
# to [0, 1]
srgb_colors = (v_colors * 0.5 + 0.5).clamp(0, 1)
# to Linear RGB (required for GLTF)
linear_colors = torch.pow(srgb_colors, 2.2)
final_colors = linear_colors.unsqueeze(0)
out_mesh = copy.deepcopy(mesh)
out_mesh.colors = final_colors