mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
512 lines
21 KiB
Python
512 lines
21 KiB
Python
"""Lens denoising transformer (DiT)"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import comfy.ldm.flux.layers
|
|
import comfy.patcher_extension
|
|
from comfy.ldm.flux.layers import EmbedND
|
|
from comfy.ldm.flux.math import apply_rope
|
|
from comfy.ldm.modules.attention import optimized_attention
|
|
|
|
|
|
def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
    """Sinusoidal timestep projection for the Lens time embedder.

    Delegates to Flux's ``timestep_embedding`` so the frequency layout is
    exactly the one that helper produces.

    Args:
        t: Timestep tensor, forwarded unchanged.
        dim: Width of the projection (default 256).

    Returns:
        The result of ``comfy.ldm.flux.layers.timestep_embedding(t, dim)``.
    """
    embed_fn = comfy.ldm.flux.layers.timestep_embedding
    return embed_fn(t, dim)
|
|
|
|
|
|
def _lens_position_ids(
    frame: int, height: int, width: int, text_seq_len: int,
    scale_rope: bool = True, device=None,
) -> torch.Tensor:
    """Build axial (frame, h, w) position ids for a joint image + text sequence.

    Image tokens come first in (frame, h, w) row-major order, followed by
    ``text_seq_len`` text tokens.

    Args:
        frame, height, width: Size of the latent token grid.
        text_seq_len: Number of text tokens appended after the image tokens.
        scale_rope: If True, h/w coordinates are centred on 0, running from
            ``-(n - n // 2)`` up to ``n // 2 - 1``. Text positions then start
            at ``max(height // 2, width // 2)``. If False, h/w run from 0 and
            text starts at ``max(height, width)``.
        device: Device for the created tensors.

    Returns:
        Tensor of shape ``[frame * height * width + text_seq_len, 3]``. The
        caller adds a batch dim before passing it to ``EmbedND``.
    """
    def axis_coords(n: int) -> torch.Tensor:
        # Centred case equals cat(arange(-(n - n//2), 0), arange(0, n//2)).
        if scale_rope:
            return torch.arange(-(n - n // 2), n // 2, device=device)
        return torch.arange(0, n, device=device)

    text_start = max(height // 2, width // 2) if scale_rope else max(height, width)

    grid = torch.meshgrid(
        torch.arange(frame, device=device),
        axis_coords(height),
        axis_coords(width),
        indexing="ij",
    )
    # Columns are (frame, h, w). The cast mirrors allocating with torch.zeros()
    # under the default dtype.
    img_ids = torch.stack(grid, dim=-1).reshape(-1, 3).to(torch.get_default_dtype())

    # Each text token carries the same index on all three axes.
    txt_ids = (
        torch.arange(text_start, text_start + text_seq_len, device=device)
        .float()
        .unsqueeze(1)
        .expand(-1, 3)
    )
    return torch.cat((img_ids, txt_ids), dim=0)
|
|
|
|
|
|
class _TimestepEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) over a timestep projection.

    The attribute names ``linear_1`` / ``linear_2`` define the state-dict keys.
    """

    def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.linear_1 = operations.Linear(in_channels, time_embed_dim, **factory)
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, **factory)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``linear_2(silu(linear_1(x)))``."""
        hidden = F.silu(self.linear_1(x))
        return self.linear_2(hidden)
|
|
|
|
|
|
class LensTimestepProjEmbeddings(nn.Module):
    """Timestep embedding: a 256-wide sinusoidal projection followed by an MLP."""

    # Width of the sinusoidal projection fed to the MLP.
    _TIME_PROJ_DIM = 256

    def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        self.timestep_embedder = _TimestepEmbedder(
            self._TIME_PROJ_DIM, embedding_dim,
            dtype=dtype, device=device, operations=operations,
        )

    def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        """Embed ``timestep``; ``hidden_states`` is consulted only for its dtype."""
        timestep_proj = _lens_time_proj(timestep, self._TIME_PROJ_DIM)
        # The projection is cast to the activation dtype before the MLP runs.
        return self.timestep_embedder(timestep_proj.to(dtype=hidden_states.dtype))
|
|
|
|
|
|
class GateMLP(nn.Module):
    """SwiGLU feed-forward: ``w2(silu(w1(x)) * w3(x))`` with bias-free projections."""

    def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        linear_kwargs = {"bias": False, "dtype": dtype, "device": device}
        self.w1 = operations.Linear(dim, hidden_dim, **linear_kwargs)
        self.w2 = operations.Linear(hidden_dim, dim, **linear_kwargs)
        self.w3 = operations.Linear(dim, hidden_dim, **linear_kwargs)

    def forward(self, x):
        # SiLU and the gating multiply both run in place on the w1 output,
        # avoiding an extra hidden_dim-sized temporary.
        gate = F.silu(self.w1(x), inplace=True)
        gate.mul_(self.w3(x))
        return self.w2(gate)
|
|
|
|
|
|
class LensJointAttention(nn.Module):
    """Joint image + text attention with one fused QKV projection per stream.

    Each stream is projected and per-head normalised separately. The streams
    are then concatenated (image first, then text), rotated with RoPE and
    attended together. Finally the output is split back and projected per
    stream.
    """

    def __init__(
        self,
        query_dim: int,
        added_kv_proj_dim: int,
        dim_head: int = 64,
        heads: int = 8,
        out_dim: Optional[int] = None,
        eps: float = 1e-5,
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        super().__init__()
        # A given out_dim overrides `heads`: the head count becomes out_dim // dim_head.
        self.inner_dim = dim_head * heads if out_dim is None else out_dim
        self.heads = self.inner_dim // dim_head
        self.dim_head = dim_head
        self.out_dim = query_dim if out_dim is None else out_dim

        def make_norm():
            return operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)

        def make_linear(in_features: int, out_features: int):
            return operations.Linear(in_features, out_features, bias=True, dtype=dtype, device=device)

        # Per-head q/k norms: image stream first, then text stream.
        self.norm_q = make_norm()
        self.norm_k = make_norm()
        self.norm_added_q = make_norm()
        self.norm_added_k = make_norm()

        self.img_qkv = make_linear(query_dim, 3 * self.inner_dim)
        self.txt_qkv = make_linear(added_kv_proj_dim, 3 * self.inner_dim)

        # ModuleList([Linear, Identity]) keeps the to_out.0.* state-dict keys.
        self.to_out = nn.ModuleList([make_linear(self.inner_dim, self.out_dim), nn.Identity()])
        self.to_add_out = make_linear(self.inner_dim, query_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend jointly over image and text tokens.

        Args:
            hidden_states: Image tokens ``[B, S_img, query_dim]``.
            encoder_hidden_states: Text tokens ``[B, S_txt, added_kv_proj_dim]``.
            freqs_cis: RoPE table handed to ``apply_rope`` for the joint
                ``[image | text]`` sequence.
            attention_mask: Optional mask of shape ``(B, 1, 1, S_img + S_txt)``.
                It is cast to the query dtype and passed to
                ``optimized_attention``.
            transformer_options: Forwarded to ``optimized_attention``.

        Returns:
            ``(img_out, txt_out)``: the attention output split back per stream.

        Raises:
            ValueError: if ``attention_mask`` has any other shape.
        """
        batch, img_len, _ = hidden_states.shape
        txt_len = encoder_hidden_states.shape[1]
        qkv_shape = (3, self.heads, self.dim_head)

        # Image stream: fused projection viewed as [B, S_img, 3, H, D], then split.
        img_q, img_k, img_v = self.img_qkv(hidden_states).view(batch, img_len, *qkv_shape).unbind(dim=2)
        img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
        # Copy v out of the fused projection (presumably so that buffer can be released early).
        img_v = img_v.contiguous()

        # Text stream: same layout, with its own projection and norms.
        txt_q, txt_k, txt_v = self.txt_qkv(encoder_hidden_states).view(batch, txt_len, *qkv_shape).unbind(dim=2)
        txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)

        # Joint sequence is [image | text], transposed [B, S, H, D] -> [B, H, S, D].
        # Per-stream pieces are dropped right after each concat to limit peak VRAM.
        q = torch.cat((img_q, txt_q), dim=1).transpose(1, 2)
        del img_q, txt_q
        k = torch.cat((img_k, txt_k), dim=1).transpose(1, 2)
        del img_k, txt_k
        v = torch.cat((img_v, txt_v), dim=1).transpose(1, 2)
        del img_v, txt_v

        q, k = apply_rope(q, k, freqs_cis)

        if attention_mask is not None:
            expected = (batch, 1, 1, img_len + txt_len)
            if attention_mask.shape != expected:
                raise ValueError(
                    f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}"
                )
            attention_mask = attention_mask.to(q.dtype)

        out = optimized_attention(
            q, k, v, self.heads,
            mask=attention_mask,
            skip_reshape=True,
            transformer_options=transformer_options,
        )

        # Image part goes through to_out (Linear, then Identity); text through to_add_out.
        img_out = out[:, :img_len, :]
        for layer in self.to_out:
            img_out = layer(img_out)
        txt_out = self.to_add_out(out[:, img_len:, :])
        return img_out, txt_out
|
|
|
|
|
|
class LensTransformerBlock(nn.Module):
    """One dual-stream MMDiT block.

    The image and text streams each get their own modulation (``*_mod``),
    two pre-norms and a SwiGLU MLP. The streams interact only inside the
    shared ``LensJointAttention``.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
        rms_norm: bool = True,
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        self.attn = LensJointAttention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            eps=1e-5,  # q/k norms use a fixed eps, independent of `eps`
            operations=operations,
            **factory,
        )

        norm_cls = operations.RMSNorm if rms_norm else operations.LayerNorm
        norm_kwargs = {"eps": eps, **factory}
        if not rms_norm:
            # The LayerNorm variant carries no learnable affine parameters.
            norm_kwargs["elementwise_affine"] = False

        def make_mod():
            # Sequential(SiLU, Linear) so the state dict lands at *_mod.1.{weight,bias}.
            return nn.Sequential(
                nn.SiLU(),
                operations.Linear(dim, 6 * dim, bias=True, **factory),
            )

        # SwiGLU hidden width: 8/3 of dim.
        mlp_hidden = int(dim / 3 * 8)

        self.img_mod = make_mod()
        self.img_norm1 = norm_cls(dim, **norm_kwargs)
        self.img_norm2 = norm_cls(dim, **norm_kwargs)
        self.img_mlp = GateMLP(dim, mlp_hidden, operations=operations, **factory)

        self.txt_mod = make_mod()
        self.txt_norm1 = norm_cls(dim, **norm_kwargs)
        self.txt_norm2 = norm_cls(dim, **norm_kwargs)
        self.txt_mlp = GateMLP(dim, mlp_hidden, operations=operations, **factory)

    @staticmethod
    def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split ``mod_params`` into (shift, scale, gate).

        Returns ``(x * (1 + scale) + shift, gate)``. Each part gets a new
        axis at dim 1 so it broadcasts over the token dimension.
        """
        shift, scale, gate = (part.unsqueeze(1) for part in mod_params.chunk(3, dim=-1))
        return x * (1 + scale) + shift, gate

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run gated attention then gated MLP on both streams.

        Returns:
            ``(encoder_hidden_states, hidden_states)``, with text first.
        """
        # Each *_mod output holds two 3-way (shift, scale, gate) sets: attention, then MLP.
        img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)
        txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)

        img_in, img_gate = self._modulate(self.img_norm1(hidden_states), img_mod_attn)
        txt_in, txt_gate = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod_attn)

        img_attn, txt_attn = self.attn(
            hidden_states=img_in,
            encoder_hidden_states=txt_in,
            freqs_cis=freqs_cis,
            attention_mask=attention_mask,
            transformer_options=transformer_options,
        )
        hidden_states = hidden_states + img_gate * img_attn
        encoder_hidden_states = encoder_hidden_states + txt_gate * txt_attn

        img_in, img_gate = self._modulate(self.img_norm2(hidden_states), img_mod_mlp)
        hidden_states = hidden_states + img_gate * self.img_mlp(img_in)

        txt_in, txt_gate = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod_mlp)
        encoder_hidden_states = encoder_hidden_states + txt_gate * self.txt_mlp(txt_in)

        return encoder_hidden_states, hidden_states
|
|
|
|
|
|
class _AdaLayerNormContinuousNoAffine(nn.Module):
    """Parameter-free LayerNorm followed by conditioning-driven scale/shift.

    Mirrors ``AdaLayerNormContinuous(elementwise_affine=False)``. The
    projected conditioning splits as ``scale, shift`` (scale first), which is
    the reverse of the ordering in Flux's ``LastLayer``.
    """

    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6,
                 dtype=None, device=None, operations=None) -> None:
        super().__init__()
        self.linear = operations.Linear(
            conditioning_embedding_dim, 2 * embedding_dim, bias=True, dtype=dtype, device=device,
        )
        self.eps = eps
        self.embedding_dim = embedding_dim

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        """Return ``layer_norm(x) * (1 + scale) + shift`` with (scale, shift) from ``conditioning``."""
        scale, shift = self.linear(F.silu(conditioning)).chunk(2, dim=-1)
        normed = F.layer_norm(x, (self.embedding_dim,), weight=None, bias=None, eps=self.eps)
        # New axis at dim 1 broadcasts the per-sample modulation over tokens.
        return normed * (1 + scale[:, None]) + shift[:, None]
|
|
|
|
|
|
class LensTransformer2DModel(nn.Module):
    """Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text).

    The figures above are the constructor defaults:
    ``num_layers=48`` and ``inner_dim = num_attention_heads * attention_head_dim``
    (``24 * 64``). With ``multi_layer_encoder_feature`` the text context is
    the concatenation of several encoder layers. Each layer gets its own
    RMSNorm before a shared input projection.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 128,
        out_channels: Optional[int] = 32,
        num_layers: int = 48,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        enc_hidden_dim: int = 2880,
        axes_dims_rope: Tuple[int, int, int] = (8, 28, 28),
        rms_norm: bool = True,
        multi_layer_encoder_feature: bool = True,
        selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23),
        image_model=None,  # unused; accepted for detection-side configs.
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else in_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        self.multi_layer_encoder_feature = multi_layer_encoder_feature
        self.selected_layer_index = list(selected_layer_index)
        self.dtype = dtype

        self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
        self.time_text_embed = LensTimestepProjEmbeddings(
            embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
        )

        if self.multi_layer_encoder_feature:
            # One RMSNorm per selected encoder layer; the normed layers are concatenated before txt_in.
            self.txt_norm = nn.ModuleList(
                [operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
                 for _ in self.selected_layer_index]
            )
            self.txt_in = operations.Linear(
                enc_hidden_dim * len(self.selected_layer_index),
                self.inner_dim, bias=True, dtype=dtype, device=device,
            )
        else:
            self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
            self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device)

        self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)

        self.transformer_blocks = nn.ModuleList([
            LensTransformerBlock(
                dim=self.inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                eps=1e-6,
                rms_norm=rms_norm,
                dtype=dtype, device=device, operations=operations,
            )
            for _ in range(num_layers)
        ])

        self.norm_out = _AdaLayerNormContinuousNoAffine(
            self.inner_dim, self.inner_dim, eps=1e-6,
            dtype=dtype, device=device, operations=operations,
        )
        self.proj_out = operations.Linear(
            self.inner_dim, patch_size * patch_size * self.out_channels, bias=True,
            dtype=dtype, device=device,
        )

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor:
        """Run ``_forward`` through any registered DIFFUSION_MODEL wrappers."""
        if transformer_options is None:
            transformer_options = {}
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward, self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
        ).execute(x, timestep, context, attention_mask, transformer_options, **kwargs)

    def _forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
        control: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``.

        Args:
            x: Latent laid out as ``[B, C, h, w]``. Each spatial position
                becomes one image token.
            timestep: Timestep tensor, cast to the activation dtype before
                embedding.
            context: Text features. With ``multi_layer_encoder_feature`` the
                last dim holds ``L = len(selected_layer_index)`` concatenated
                encoder layers.
            attention_mask: Optional per-text-token mask ``[B, S]``, where a
                truthy value means attend. ``None`` means every text token is
                attended, and no mask tensor is built or passed to attention.
            transformer_options: ComfyUI patch options. They are shallow-copied
                before per-block metadata is written.
            control: Optional ControlNet residuals. ``control["input"][i]`` is
                added to the leading image tokens after block ``i``.

        Returns:
            Tensor ``[B, patch_size**2 * out_channels, h, w]``.

        Raises:
            ValueError: if the multi-layer context width is not divisible by
                the number of selected encoder layers.
        """
        if transformer_options is None:
            transformer_options = {}
        transformer_options = transformer_options.copy()
        patches = transformer_options.get("patches", {})
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})

        B, C, h, w = x.shape
        img_len = h * w
        hidden_states = x.permute(0, 2, 3, 1).reshape(B, img_len, C)

        if self.multi_layer_encoder_feature:
            encoder_hidden_states = self._split_multi_layer_context(context, len(self.selected_layer_index))
            text_seq_len = encoder_hidden_states[0].shape[1]
        else:
            encoder_hidden_states = context
            text_seq_len = context.shape[1]

        # Without a text mask every token is visible. An all-zero additive mask
        # would change nothing numerically, but it forces the slower masked
        # attention path, so attention gets no mask at all.
        joint_mask = None
        if attention_mask is not None:
            joint_mask = self._build_joint_attention_mask(attention_mask, img_len)

        hidden_states = self.img_in(hidden_states)
        timestep = timestep.to(hidden_states.dtype)

        if self.multi_layer_encoder_feature:
            # Normalise each selected encoder layer with its own RMSNorm, then re-concatenate.
            encoder_hidden_states = torch.cat(
                [norm(layer) for norm, layer in zip(self.txt_norm, encoder_hidden_states)],
                dim=-1,
            )
        else:
            encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        if "post_input" in patches:
            for p in patches["post_input"]:
                out = p({
                    "img": hidden_states,
                    "txt": encoder_hidden_states,
                    "transformer_options": transformer_options,
                })
                hidden_states = out["img"]
                encoder_hidden_states = out["txt"]

        temb = self.time_text_embed(timestep, hidden_states)
        ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
        freqs_cis = self.pos_embed(ids)

        transformer_options["total_blocks"] = len(self.transformer_blocks)
        transformer_options["block_type"] = "double"
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["txt"], out["img"] = block(
                        hidden_states=args["img"],
                        encoder_hidden_states=args["txt"],
                        temb=args["vec"],
                        freqs_cis=args["pe"],
                        attention_mask=args.get("attn_mask"),
                        transformer_options=args.get("transformer_options"),
                    )
                    return out
                out = blocks_replace[("double_block", i)](
                    {
                        "img": hidden_states,
                        "txt": encoder_hidden_states,
                        "vec": temb,
                        "pe": freqs_cis,
                        "attn_mask": joint_mask,
                        "transformer_options": transformer_options,
                    },
                    {"original_block": block_wrap},
                )
                encoder_hidden_states = out["txt"]
                hidden_states = out["img"]
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    freqs_cis=freqs_cis,
                    attention_mask=joint_mask,
                    transformer_options=transformer_options,
                )

            if "double_block" in patches:
                for p in patches["double_block"]:
                    out = p({
                        "img": hidden_states,
                        "txt": encoder_hidden_states,
                        "x": x,
                        "block_index": i,
                        "transformer_options": transformer_options,
                    })
                    hidden_states = out["img"]
                    encoder_hidden_states = out["txt"]

            if control is not None:
                control_i = control.get("input")
                if control_i is not None and i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        hidden_states[:, :add.shape[1]] += add

        hidden_states = self.norm_out(hidden_states, temb)
        out = self.proj_out(hidden_states)
        # proj_out emits patch_size**2 * out_channels features per token, which
        # need not equal the input channel count C, so the channel dim is inferred.
        return out.reshape(B, h, w, -1).permute(0, 3, 1, 2).contiguous()

    @staticmethod
    def _split_multi_layer_context(context: torch.Tensor, num_layers: int) -> list:
        """Split ``context[..., S, L*H]`` into ``L`` tensors of shape ``[..., S, H]``.

        The split uses the context's own leading dims, so it never borrows the
        image batch size.

        Raises:
            ValueError: if ``num_layers`` is not positive or does not divide
                the last dimension. Without this check the reshape could
                silently scramble the layers.
        """
        width = context.shape[-1]
        if num_layers <= 0 or width % num_layers != 0:
            raise ValueError(
                f"context feature dim {width} is not divisible by the number of "
                f"selected encoder layers ({num_layers})"
            )
        per_layer = width // num_layers
        return list(context.reshape(*context.shape[:-1], num_layers, per_layer).unbind(dim=-2))

    @staticmethod
    def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor:
        """Turn a ``[B, S_txt]`` text mask into an additive ``[B, 1, 1, img_len + S_txt]`` mask.

        Image tokens are always visible. Text tokens whose mask value is
        falsy get ``finfo(float32).min``; visible tokens get 0.
        """
        if text_mask.dtype != torch.bool:
            text_mask = text_mask.bool()
        bsz = text_mask.shape[0]
        img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device)
        joint = torch.cat([img_ones, text_mask], dim=1)
        additive = torch.zeros_like(joint, dtype=torch.float32)
        additive.masked_fill_(~joint, torch.finfo(torch.float32).min)
        return additive[:, None, None, :]
|