diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index 06f2fbf74..758704108 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -411,9 +411,9 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
 
 
 def apply_rope(xq, xk, freqs_cis):
     org_dtype = xq.dtype
-    cos = freqs_cis[0]
-    sin = freqs_cis[1]
-    nsin = freqs_cis[2]
+    cos = freqs_cis[0].to(xq.device)
+    sin = freqs_cis[1].to(xq.device)
+    nsin = freqs_cis[2].to(xq.device)
     q_embed = (xq * cos)
     q_split = q_embed.shape[-1] // 2
diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py
index ce9b07464..f684edd57 100644
--- a/comfy/text_encoders/qwen35.py
+++ b/comfy/text_encoders/qwen35.py
@@ -661,7 +661,7 @@ class Qwen35VisionModel(nn.Module):
         cos = emb.cos().unsqueeze(-2)
         sin = emb.sin().unsqueeze(-2)
         sin_half = sin.shape[-1] // 2
-        position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
+        position_embeddings = (cos.to(x.device), sin[..., :sin_half].to(x.device), -sin[..., sin_half:].to(x.device))
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)