mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 19:17:32 +08:00
Merge 987dce1db9 into 95f6652ef5
This commit is contained in:
commit
ae4f71ebfe
@ -435,9 +435,9 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
|
|||||||
|
|
||||||
def apply_rope(xq, xk, freqs_cis):
|
def apply_rope(xq, xk, freqs_cis):
|
||||||
org_dtype = xq.dtype
|
org_dtype = xq.dtype
|
||||||
cos = freqs_cis[0]
|
cos = freqs_cis[0].to(xq.device)
|
||||||
sin = freqs_cis[1]
|
sin = freqs_cis[1].to(xq.device)
|
||||||
nsin = freqs_cis[2]
|
nsin = freqs_cis[2].to(xq.device)
|
||||||
|
|
||||||
q_embed = (xq * cos)
|
q_embed = (xq * cos)
|
||||||
q_split = q_embed.shape[-1] // 2
|
q_split = q_embed.shape[-1] // 2
|
||||||
|
|||||||
@ -213,7 +213,10 @@ class GatedDeltaNet(nn.Module):
|
|||||||
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, seq_len, conv_dim]
|
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, seq_len, conv_dim]
|
||||||
query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
|
query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
|
||||||
beta = b.sigmoid()
|
beta = b.sigmoid()
|
||||||
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())
|
|
||||||
|
A_log = comfy.model_management.cast_to_device(self.A_log, x.device, torch.float32)
|
||||||
|
dt_bias = comfy.model_management.cast_to_device(self.dt_bias, x.device, torch.float32)
|
||||||
|
g = -A_log.exp() * F.softplus(a.float() + dt_bias)
|
||||||
|
|
||||||
# Delta rule
|
# Delta rule
|
||||||
if use_recurrent:
|
if use_recurrent:
|
||||||
@ -474,9 +477,15 @@ class Qwen35VisionRotaryEmbedding(nn.Module):
|
|||||||
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
def forward(self, seqlen):
|
def forward(self, seqlen, device=None, dtype=None):
|
||||||
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
if device is None:
|
||||||
freqs = torch.outer(seq, self.inv_freq)
|
device = self.inv_freq.device
|
||||||
|
if dtype is None:
|
||||||
|
dtype = self.inv_freq.dtype
|
||||||
|
|
||||||
|
inv_freq = comfy.model_management.cast_to_device(self.inv_freq, device, dtype)
|
||||||
|
seq = torch.arange(seqlen, device=device, dtype=dtype)
|
||||||
|
freqs = torch.outer(seq, inv_freq)
|
||||||
return freqs
|
return freqs
|
||||||
|
|
||||||
|
|
||||||
@ -565,12 +574,11 @@ class Qwen35VisionModel(nn.Module):
|
|||||||
])
|
])
|
||||||
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
||||||
|
|
||||||
def rot_pos_emb(self, grid_thw):
|
def rot_pos_emb(self, grid_thw, device):
|
||||||
merge_size = self.spatial_merge_size
|
merge_size = self.spatial_merge_size
|
||||||
grid_thw_list = grid_thw.tolist()
|
grid_thw_list = grid_thw.tolist()
|
||||||
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
|
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
|
||||||
freq_table = self.rotary_pos_emb(max_hw)
|
freq_table = self.rotary_pos_emb(max_hw, device=device, dtype=torch.float32)
|
||||||
device = freq_table.device
|
|
||||||
total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
|
total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
|
||||||
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
|
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
|
||||||
offset = 0
|
offset = 0
|
||||||
@ -651,7 +659,7 @@ class Qwen35VisionModel(nn.Module):
|
|||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
|
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
|
||||||
x = x + pos_embeds
|
x = x + pos_embeds
|
||||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
rotary_pos_emb = self.rot_pos_emb(grid_thw, device=x.device)
|
||||||
seq_len = x.shape[0]
|
seq_len = x.shape[0]
|
||||||
x = x.reshape(seq_len, -1)
|
x = x.reshape(seq_len, -1)
|
||||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||||
@ -659,7 +667,7 @@ class Qwen35VisionModel(nn.Module):
|
|||||||
cos = emb.cos().unsqueeze(-2)
|
cos = emb.cos().unsqueeze(-2)
|
||||||
sin = emb.sin().unsqueeze(-2)
|
sin = emb.sin().unsqueeze(-2)
|
||||||
sin_half = sin.shape[-1] // 2
|
sin_half = sin.shape[-1] // 2
|
||||||
position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
|
position_embeddings = (cos.to(x.device), sin[..., :sin_half].to(x.device), -sin[..., sin_half:].to(x.device))
|
||||||
cu_seqlens = torch.repeat_interleave(
|
cu_seqlens = torch.repeat_interleave(
|
||||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||||
).cumsum(dim=0, dtype=torch.int32)
|
).cumsum(dim=0, dtype=torch.int32)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user