Fix device mismatch. 1. In apply_rope, move the RoPE frequency tensors (cos, sin, nsin) to the device of the input tensor xq. 2. In Qwen35VisionModel.forward, move position_embeddings to x.device.

This commit is contained in:
silveroxides 2026-04-05 11:57:21 +02:00
parent eb0686bbb6
commit d6756e5c97
2 changed files with 4 additions and 4 deletions

View File

@@ -411,9 +411,9 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
def apply_rope(xq, xk, freqs_cis): def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype org_dtype = xq.dtype
cos = freqs_cis[0] cos = freqs_cis[0].to(xq.device)
sin = freqs_cis[1] sin = freqs_cis[1].to(xq.device)
nsin = freqs_cis[2] nsin = freqs_cis[2].to(xq.device)
q_embed = (xq * cos) q_embed = (xq * cos)
q_split = q_embed.shape[-1] // 2 q_split = q_embed.shape[-1] // 2

View File

@@ -661,7 +661,7 @@ class Qwen35VisionModel(nn.Module):
cos = emb.cos().unsqueeze(-2) cos = emb.cos().unsqueeze(-2)
sin = emb.sin().unsqueeze(-2) sin = emb.sin().unsqueeze(-2)
sin_half = sin.shape[-1] // 2 sin_half = sin.shape[-1] // 2
position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:]) position_embeddings = (cos.to(x.device), sin[..., :sin_half].to(x.device), -sin[..., sin_half:].to(x.device))
cu_seqlens = torch.repeat_interleave( cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32) ).cumsum(dim=0, dtype=torch.int32)