diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py
index f684edd57..c15a5bf7b 100644
--- a/comfy/text_encoders/qwen35.py
+++ b/comfy/text_encoders/qwen35.py
@@ -213,7 +213,10 @@ class GatedDeltaNet(nn.Module):
         mixed_qkv = mixed_qkv.transpose(1, 2)  # [B, seq_len, conv_dim]
         query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
         beta = b.sigmoid()
-        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())
+
+        A_log = comfy.model_management.cast_to_device(self.A_log, x.device, torch.float32)
+        dt_bias = comfy.model_management.cast_to_device(self.dt_bias, x.device, torch.float32)
+        g = -A_log.exp() * F.softplus(a.float() + dt_bias)
 
         # Delta rule
         if use_recurrent:
@@ -476,9 +479,15 @@ class Qwen35VisionRotaryEmbedding(nn.Module):
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
-    def forward(self, seqlen):
-        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.outer(seq, self.inv_freq)
+    def forward(self, seqlen, device=None, dtype=None):
+        if device is None:
+            device = self.inv_freq.device
+        if dtype is None:
+            dtype = self.inv_freq.dtype
+
+        inv_freq = comfy.model_management.cast_to_device(self.inv_freq, device, dtype)
+        seq = torch.arange(seqlen, device=device, dtype=dtype)
+        freqs = torch.outer(seq, inv_freq)
         return freqs
 
 
@@ -567,12 +576,11 @@ class Qwen35VisionModel(nn.Module):
         ])
         self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
 
-    def rot_pos_emb(self, grid_thw):
+    def rot_pos_emb(self, grid_thw, device):
         merge_size = self.spatial_merge_size
         grid_thw_list = grid_thw.tolist()
         max_hw = max(max(h, w) for _, h, w in grid_thw_list)
-        freq_table = self.rotary_pos_emb(max_hw)
-        device = freq_table.device
+        freq_table = self.rotary_pos_emb(max_hw, device=device, dtype=torch.float32)
         total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
         pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
         offset = 0
@@ -653,7 +661,7 @@ class Qwen35VisionModel(nn.Module):
         x = self.patch_embed(x)
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
         x = x + pos_embeds
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw, device=x.device)
         seq_len = x.shape[0]
         x = x.reshape(seq_len, -1)
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)