From 252ff69fd54ead5be014acca9aa78afeda1b530d Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 26 Feb 2026 00:24:50 +0200 Subject: [PATCH 01/11] initial rfdetrv4 support --- comfy/ldm/rf_detr/rfdetr_v4.py | 700 +++++++++++++++++++++++++++++++++ comfy/model_base.py | 5 + comfy/model_detection.py | 6 + comfy/supported_models.py | 12 +- comfy_api/latest/_io.py | 5 +- comfy_extras/nodes_rfdetr.py | 154 ++++++++ nodes.py | 1 + 7 files changed, 881 insertions(+), 2 deletions(-) create mode 100644 comfy/ldm/rf_detr/rfdetr_v4.py create mode 100644 comfy_extras/nodes_rfdetr.py diff --git a/comfy/ldm/rf_detr/rfdetr_v4.py b/comfy/ldm/rf_detr/rfdetr_v4.py new file mode 100644 index 000000000..fc76e7852 --- /dev/null +++ b/comfy/ldm/rf_detr/rfdetr_v4.py @@ -0,0 +1,700 @@ +from collections import OrderedDict +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +COCO_CLASSES = [ + 'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat', + 'traffic light','fire hydrant','stop sign','parking meter','bench','bird','cat', + 'dog','horse','sheep','cow','elephant','bear','zebra','giraffe','backpack', + 'umbrella','handbag','tie','suitcase','frisbee','skis','snowboard','sports ball', + 'kite','baseball bat','baseball glove','skateboard','surfboard','tennis racket', + 'bottle','wine glass','cup','fork','knife','spoon','bowl','banana','apple', + 'sandwich','orange','broccoli','carrot','hot dog','pizza','donut','cake','chair', + 'couch','potted plant','bed','dining table','toilet','tv','laptop','mouse', + 'remote','keyboard','cell phone','microwave','oven','toaster','sink', + 'refrigerator','book','clock','vase','scissors','teddy bear','hair drier','toothbrush', +] + +# --------------------------------------------------------------------------- +# HGNetv2 backbone +# --------------------------------------------------------------------------- + +class ConvBNAct(nn.Module): + """Conv→BN→ReLU. padding='same' adds asymmetric zero-pad (stem).""" + def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True): + super().__init__() + + self.conv = nn.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False) + self.bn = nn.BatchNorm2d(oc) + self.act = nn.ReLU() if use_act else nn.Identity() + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + +class LightConvBNAct(nn.Module): + def __init__(self, ic, oc, k): + super().__init__() + self.conv1 = ConvBNAct(ic, oc, 1, use_act=False) + self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True) + + def forward(self, x): + return self.conv2(self.conv1(x)) + +class _StemBlock(nn.Module): + def __init__(self, ic, mc, oc): + super().__init__() + self.stem1 = ConvBNAct(ic, mc, 3, 2) + # stem2a/stem2b use kernel=2, stride=1, no internal padding; + # padding is applied manually in forward (matching PaddlePaddle original) + self.stem2a = ConvBNAct(mc, mc//2, 2, 1) + self.stem2b = ConvBNAct(mc//2, mc, 2, 1) + self.stem3 = ConvBNAct(mc*2, mc, 3, 2) + self.stem4 = ConvBNAct(mc, oc, 1) + self.pool = nn.MaxPool2d(2, 1, ceil_mode=True) + + def forward(self, x): + x = self.stem1(x) + x = F.pad(x, (0, 1, 0, 1)) # pad before pool and stem2a + x2 = self.stem2a(x) + x2 = F.pad(x2, (0, 1, 0, 1)) # pad before stem2b + x2 = self.stem2b(x2) + x1 = self.pool(x) + return self.stem4(self.stem3(torch.cat([x1, x2], 1))) + + +class _HG_Block(nn.Module): + def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False): + super().__init__() + self.residual = residual + if light: + self.layers = nn.ModuleList( + [LightConvBNAct(ic if i == 0 else mc, mc, k) for i in range(layer_num)]) + else: + self.layers = nn.ModuleList( + [ConvBNAct(ic if i == 0 else mc, mc, k) for i in range(layer_num)]) + total = ic + layer_num * mc + + self.aggregation = nn.Sequential( + ConvBNAct(total, oc // 2, 1), + ConvBNAct(oc // 2, oc, 1)) + + def forward(self, x): + identity = x + outs = [x] + for layer in self.layers: + x = layer(x) + outs.append(x) + x = self.aggregation(torch.cat(outs, 1)) + return x + identity if self.residual else x + + +class _HG_Stage(nn.Module): + # config order: ic, mc, oc, num_blocks, downsample, light, k, layer_num + def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6): + super().__init__() + if downsample: + self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False) + else: + self.downsample = nn.Identity() + self.blocks = nn.Sequential(*[ + _HG_Block(ic if i == 0 else oc, mc, oc, layer_num, + k=k, residual=(i != 0), light=light) + for i in range(num_blocks) + ]) + + def forward(self, x): + return self.blocks(self.downsample(x)) + + +class HGNetv2(nn.Module): + # B5 config: stem=[3,32,64], stages=[ic, mc, oc, blocks, down, light, k, layers] + _STAGE_CFGS = [[64, 64, 128, 1, False, False, 3, 6], + [128, 128, 512, 2, True, False, 3, 6], + [512, 256, 1024, 5, True, True, 5, 6], + [1024,512, 2048, 2, True, True, 5, 6]] + + def __init__(self, return_idx=(1, 2, 3)): + super().__init__() + self.stem = _StemBlock(3, 32, 64) + self.stages = nn.ModuleList([_HG_Stage(*cfg) for cfg in self._STAGE_CFGS]) + self.return_idx = list(return_idx) + self.out_channels = [self._STAGE_CFGS[i][2] for i in return_idx] + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.stem(x) + outs = [] + for i, stage in enumerate(self.stages): + x = stage(x) + if i in self.return_idx: + outs.append(x) + return outs + + +# --------------------------------------------------------------------------- +# Encoder — HybridEncoder (dfine version: RepNCSPELAN4 + SCDown PAN) +# --------------------------------------------------------------------------- + +class ConvNormLayer(nn.Module): + """Conv→act (expects pre-fused BN weights).""" + def __init__(self, ic, oc, k, s, g=1, padding=None, act=None): + super().__init__() + p = (k - 1) // 2 if padding is None else padding + self.conv = nn.Conv2d(ic, oc, k, s, p, groups=g, bias=True) + self.act = nn.SiLU() if act == 'silu' else nn.Identity() + + def forward(self, x): + return self.act(self.conv(x)) + + +class VGGBlock(nn.Module): + """Rep-VGG block (expects pre-fused weights).""" + def __init__(self, ic, oc): + super().__init__() + self.conv = nn.Conv2d(ic, oc, 3, 1, padding=1, bias=True) + self.act = nn.SiLU() + + def forward(self, x): + return self.act(self.conv(x)) + + +class CSPLayer(nn.Module): + def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu'): + super().__init__() + h = int(oc * expansion) + self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act) + self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act) + self.bottlenecks = nn.Sequential(*[VGGBlock(h, h) for _ in range(num_blocks)]) + self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act) if h != oc else nn.Identity() + + def forward(self, x): + return self.conv3(self.bottlenecks(self.conv1(x)) + self.conv2(x)) + + +class RepNCSPELAN4(nn.Module): + """CSP-ELAN block — the FPN/PAN block in RTv4's HybridEncoder.""" + def __init__(self, c1, c2, c3, c4, n=3, act='silu'): + super().__init__() + self.c = c3 // 2 + self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act) + self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act), ConvNormLayer(c4, c4, 3, 1, act=act)) + self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act), ConvNormLayer(c4, c4, 3, 1, act=act)) + self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act) + + def forward(self, x): + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + +class SCDown(nn.Module): + """Separable conv downsampling used in HybridEncoder PAN bottom-up path.""" + def __init__(self, ic, oc, k, s): + super().__init__() + self.cv1 = ConvNormLayer(ic, oc, 1, 1) + self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc) + + def forward(self, x): + return self.cv2(self.cv1(x)) + + +class _TransformerEncoderLayer(nn.Module): + """Single AIFI encoder layer (pre- or post-norm, GELU by default).""" + def __init__(self, d_model, nhead, dim_feedforward): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.activation = nn.GELU() + + def forward(self, src, src_mask=None, pos_embed=None): + q = k = src if pos_embed is None else src + pos_embed + src2, _ = self.self_attn(q, k, value=src, attn_mask=src_mask) + src = self.norm1(src + src2) + src2 = self.linear2(self.activation(self.linear1(src))) + return self.norm2(src + src2) + + +class _TransformerEncoder(nn.Module): + """Thin wrapper so state-dict keys are encoder.0.layers.N.*""" + def __init__(self, num_layers, d_model, nhead, dim_feedforward): + super().__init__() + self.layers = nn.ModuleList([ + _TransformerEncoderLayer(d_model, nhead, dim_feedforward) + for _ in range(num_layers) + ]) + + def forward(self, src, src_mask=None, pos_embed=None): + for layer in self.layers: + src = layer(src, src_mask=src_mask, pos_embed=pos_embed) + return src + + +class HybridEncoder(nn.Module): + def __init__(self, in_channels=(512, 1024, 2048), feat_strides=(8, 16, 32), hidden_dim=256, nhead=8, dim_feedforward=2048, use_encoder_idx=(2,), num_encoder_layers=1, + pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640)): + super().__init__() + self.in_channels = list(in_channels) + self.feat_strides = list(feat_strides) + self.hidden_dim = hidden_dim + self.use_encoder_idx = list(use_encoder_idx) + self.pe_temperature = pe_temperature + self.eval_spatial_size = eval_spatial_size + self.out_channels = [hidden_dim] * len(in_channels) + self.out_strides = list(feat_strides) + + # channel projection (expects pre-fused weights) + self.input_proj = nn.ModuleList([ + nn.Sequential(OrderedDict([('conv', nn.Conv2d(ch, hidden_dim, 1, bias=True))])) + for ch in in_channels + ]) + + # AIFI transformer — use _TransformerEncoder so keys are encoder.0.layers.N.* + self.encoder = nn.ModuleList([ + _TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward) + for _ in range(len(use_encoder_idx)) + ]) + + nb = round(3 * depth_mult) + exp = expansion + + # top-down FPN (dfine: lateral conv has no act) + self.lateral_convs = nn.ModuleList( + [ConvNormLayer(hidden_dim, hidden_dim, 1, 1) + for _ in range(len(in_channels) - 1)]) + self.fpn_blocks = nn.ModuleList( + [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act) + for _ in range(len(in_channels) - 1)]) + + # bottom-up PAN (dfine: nn.Sequential(SCDown) — keeps checkpoint key .0.cv1/.0.cv2) + self.downsample_convs = nn.ModuleList( + [nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2)) + for _ in range(len(in_channels) - 1)]) + self.pan_blocks = nn.ModuleList( + [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act) + for _ in range(len(in_channels) - 1)]) + + # cache positional embeddings for fixed spatial size + if eval_spatial_size: + for idx in self.use_encoder_idx: + stride = self.feat_strides[idx] + pe = self._build_pe(eval_spatial_size[1] // stride, + eval_spatial_size[0] // stride, + hidden_dim, pe_temperature) + setattr(self, f'pos_embed{idx}', pe) + + @staticmethod + def _build_pe(w, h, dim=256, temp=10000.): + assert dim % 4 == 0 + gw = torch.arange(w, dtype=torch.float32) + gh = torch.arange(h, dtype=torch.float32) + gw, gh = torch.meshgrid(gw, gh, indexing='ij') + pdim = dim // 4 + omega = 1. / (temp ** (torch.arange(pdim, dtype=torch.float32) / pdim)) + ow = gw.flatten()[:, None] @ omega[None] + oh = gh.flatten()[:, None] @ omega[None] + return torch.cat([ow.sin(), ow.cos(), oh.sin(), oh.cos()], 1)[None] + + def forward(self, feats: List[torch.Tensor]) -> List[torch.Tensor]: + proj = [self.input_proj[i](f) for i, f in enumerate(feats)] + + for i, enc_idx in enumerate(self.use_encoder_idx): + h, w = proj[enc_idx].shape[2:] + src = proj[enc_idx].flatten(2).permute(0, 2, 1) + pe = getattr(self, f'pos_embed{enc_idx}').to(src.device) + for layer in self.encoder[i].layers: + src = layer(src, pos_embed=pe) + proj[enc_idx] = src.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous() + + n = len(self.in_channels) + inner = [proj[-1]] + for k in range(n - 1, 0, -1): + j = n - 1 - k + top = self.lateral_convs[j](inner[0]) + inner[0] = top + up = F.interpolate(top, scale_factor=2., mode='nearest') + inner.insert(0, self.fpn_blocks[j](torch.cat([up, proj[k - 1]], 1))) + + outs = [inner[0]] + for k in range(n - 1): + outs.append(self.pan_blocks[k]( + torch.cat([self.downsample_convs[k](outs[-1]), inner[k + 1]], 1))) + return outs + + +# --------------------------------------------------------------------------- +# Decoder — DFINETransformer +# --------------------------------------------------------------------------- + +def _deformable_attn_v2(value: list, spatial_shapes, sampling_locations: torch.Tensor, attention_weights: torch.Tensor, num_points_list: List[int]) -> torch.Tensor: + """ + value : list of per-level tensors [bs*n_head, c, h_l, w_l] + sampling_locations: [bs, Lq, n_head, sum(pts), 2] in [0,1] + attention_weights : [bs, Lq, n_head, sum(pts)] + """ + _, c = value[0].shape[:2] # bs*n_head, c + _, Lq, n_head, _, _ = sampling_locations.shape + bs = sampling_locations.shape[0] + n_h = n_head + + grids = (2 * sampling_locations - 1) # [bs, Lq, n_head, sum_pts, 2] + grids = grids.permute(0, 2, 1, 3, 4).flatten(0, 1) # [bs*n_head, Lq, sum_pts, 2] + grids_per_lvl = grids.split(num_points_list, dim=2) # list of [bs*n_head, Lq, pts_l, 2] + + sampled = [] + for lvl, (h, w) in enumerate(spatial_shapes): + val_l = value[lvl].reshape(bs * n_h, c, h, w) + sv = F.grid_sample(val_l, grids_per_lvl[lvl], mode='bilinear', padding_mode='zeros', align_corners=False) + sampled.append(sv) # sv: [bs*n_head, c, Lq, pts_l] + + attn = attention_weights.permute(0, 2, 1, 3) # [bs, n_head, Lq, sum_pts] + attn = attn.flatten(0, 1).unsqueeze(1) # [bs*n_head, 1, Lq, sum_pts] + out = (torch.cat(sampled, -1) * attn).sum(-1) # [bs*n_head, c, Lq] + out = out.reshape(bs, n_h * c, Lq) + return out.permute(0, 2, 1) # [bs, Lq, hidden] + + +class MSDeformableAttention(nn.Module): + def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5): + super().__init__() + self.embed_dim, self.num_heads = embed_dim, num_heads + self.head_dim = embed_dim // num_heads + pts = num_points if isinstance(num_points, list) else [num_points] * num_levels + self.num_points_list = pts + self.offset_scale = offset_scale + total = num_heads * sum(pts) + self.register_buffer('num_points_scale', torch.tensor([1. / n for n in pts for _ in range(n)], dtype=torch.float32)) + self.sampling_offsets = nn.Linear(embed_dim, total * 2) + self.attention_weights = nn.Linear(embed_dim, total) + + def forward(self, query, ref_pts, value, spatial_shapes): + bs, Lq = query.shape[:2] + offsets = self.sampling_offsets(query).reshape( + bs, Lq, self.num_heads, sum(self.num_points_list), 2) + attn_w = F.softmax( + self.attention_weights(query).reshape( + bs, Lq, self.num_heads, sum(self.num_points_list)), -1) + scale = self.num_points_scale.to(query.dtype).unsqueeze(-1) + offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale + locs = ref_pts[:, :, None, :, :2] + offset # [bs, Lq, n_head, sum_pts, 2] + return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list) + + +class Gate(nn.Module): + def __init__(self, d_model): + super().__init__() + self.gate = nn.Linear(2 * d_model, 2 * d_model) + self.norm = nn.LayerNorm(d_model) + + def forward(self, x1, x2): + g1, g2 = torch.sigmoid(self.gate(torch.cat([x1, x2], -1))).chunk(2, -1) + return self.norm(g1 * x1 + g2 * x2) + + +class MLP(nn.Module): + def __init__(self, in_dim, hidden_dim, out_dim, num_layers): + super().__init__() + dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim] + self.layers = nn.ModuleList(nn.Linear(dims[i], dims[i + 1]) for i in range(num_layers)) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.SiLU()(layer(x)) if i < len(self.layers) - 1 else layer(x) + return x + + +class TransformerDecoderLayer(nn.Module): + def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) + self.norm1 = nn.LayerNorm(d_model) + self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points) + self.gateway = Gate(d_model) + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.activation = nn.ReLU() + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.norm3 = nn.LayerNorm(d_model) + + def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None): + q = k = target if query_pos is None else target + query_pos + t2, _ = self.self_attn(q, k, value=target, attn_mask=attn_mask) + target = self.norm1(target + t2) + t2 = self.cross_attn( + target if query_pos is None else target + query_pos, + ref_pts, value, spatial_shapes) + target = self.gateway(target, t2) + t2 = self.linear2(self.activation(self.linear1(target))) + target = self.norm3((target + t2).clamp(-65504, 65504)) + return target + + +# --------------------------------------------------------------------------- +# FDR utilities +# --------------------------------------------------------------------------- + +def weighting_function(reg_max, up, reg_scale): + """Non-uniform weighting function W(n) for FDR box regression.""" + ub1 = (abs(up[0]) * abs(reg_scale)).item() + ub2 = ub1 * 2 + step = (ub1 + 1) ** (2 / (reg_max - 2)) + left = [-(step ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)] + right = [ (step ** i) - 1 for i in range(1, reg_max // 2)] + vals = [-ub2] + left + [torch.zeros_like(up[0][None])] + right + [ub2] + return torch.tensor(vals, dtype=up.dtype, device=up.device) + + +def distance2bbox(points, distance, reg_scale): + """Decode edge-distances → cxcywh boxes.""" + rs = abs(reg_scale) + x1 = points[..., 0] - (0.5 * rs + distance[..., 0]) * (points[..., 2] / rs) + y1 = points[..., 1] - (0.5 * rs + distance[..., 1]) * (points[..., 3] / rs) + x2 = points[..., 0] + (0.5 * rs + distance[..., 2]) * (points[..., 2] / rs) + y2 = points[..., 1] + (0.5 * rs + distance[..., 3]) * (points[..., 3] / rs) + x0, y0, x1_, y1_ = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1 + return torch.stack([x0, y0, x1_, y1_], -1) + + +class Integral(nn.Module): + """Sum Pr(n)·W(n) over the distribution bins.""" + def __init__(self, reg_max=32): + super().__init__() + self.reg_max = reg_max + + def forward(self, x, project): + shape = x.shape + x = F.softmax(x.reshape(-1, self.reg_max + 1), 1) + x = F.linear(x, project.to(x.device)).reshape(-1, 4) + return x.reshape(list(shape[:-1]) + [-1]) + + +class LQE(nn.Module): + """Location Quality Estimator — refines class scores using corner distribution.""" + def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32): + super().__init__() + self.k, self.reg_max = k, reg_max + self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers) + + def forward(self, scores, pred_corners): + B, L, _ = pred_corners.shape + prob = F.softmax(pred_corners.reshape(B, L, 4, self.reg_max + 1), -1) + topk, _ = prob.topk(self.k, -1) + stat = torch.cat([topk, topk.mean(-1, keepdim=True)], -1) + return scores + self.reg_conf(stat.reshape(B, L, -1)) + + +class TransformerDecoder(nn.Module): + def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, + num_layers, reg_max, reg_scale, up, eval_idx=-1): + super().__init__() + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.nhead = nhead + self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx + self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max + self.layers = nn.ModuleList([ + TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points) + for _ in range(self.eval_idx + 1) + ]) + self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max) for _ in range(self.eval_idx + 1)]) + self.register_buffer('project', weighting_function(reg_max, up, reg_scale)) + + def _value_op(self, memory, spatial_shapes): + """Reshape memory to per-level value tensors for deformable attention.""" + c = self.hidden_dim // self.nhead + split = [h * w for h, w in spatial_shapes] + val = memory.reshape(memory.shape[0], memory.shape[1], self.nhead, c) # memory: [bs, sum(h*w), hidden_dim] + # → [bs, n_head, c, sum_hw] + val = val.permute(0, 2, 3, 1).flatten(0, 1) # [bs*n_head, c, sum_hw] + return val.split(split, dim=-1) # list of [bs*n_head, c, h_l*w_l] + + def forward(self, target, ref_pts_unact, memory, spatial_shapes, bbox_head, score_head, query_pos_head, pre_bbox_head, integral): + val_split_flat = self._value_op(memory, spatial_shapes) # pre-split value for deformable attention + + # reshape to [bs*n_head, c, h_l, w_l] + value = [] + for lvl, (h, w) in enumerate(spatial_shapes): + v = val_split_flat[lvl] # [bs*n_head, c, h*w] + value.append(v.reshape(v.shape[0], v.shape[1], h, w)) + + ref_pts = F.sigmoid(ref_pts_unact) + output = target + output_detach = pred_corners_undetach = 0 + + dec_bboxes, dec_logits = [], [] + + for i, layer in enumerate(self.layers): + ref_input = ref_pts.unsqueeze(2) # [bs, Lq, 1, 4] + query_pos = query_pos_head(ref_pts).clamp(-10, 10) + output = layer(output, ref_input, value, spatial_shapes, query_pos=query_pos) + + if i == 0: + ref_unact = ref_pts.clamp(1e-5, 1 - 1e-5) + ref_unact = torch.log(ref_unact / (1 - ref_unact)) + pre_bboxes = F.sigmoid(pre_bbox_head(output) + ref_unact) + ref_pts_initial = pre_bboxes.detach() + + pred_corners = bbox_head[i](output + output_detach) + pred_corners_undetach + inter_ref_bbox = distance2bbox(ref_pts_initial, integral(pred_corners, self.project), self.reg_scale) + + if i == self.eval_idx: + scores = score_head[i](output) + scores = self.lqe_layers[i](scores, pred_corners) + dec_bboxes.append(inter_ref_bbox) + dec_logits.append(scores) + break + + pred_corners_undetach = pred_corners + ref_pts = inter_ref_bbox.detach() + output_detach = output.detach() + + return torch.stack(dec_bboxes), torch.stack(dec_logits) + + +class DFINETransformer(nn.Module): + def __init__(self, num_classes=80, hidden_dim=256, num_queries=300, feat_channels=[256, 256, 256], feat_strides=[8, 16, 32], + num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32, reg_scale=8.0, eval_spatial_size=(640, 640)): + super().__init__() + assert len(feat_strides) == len(feat_channels) + self.hidden_dim = hidden_dim + self.num_queries = num_queries + self.num_levels = num_levels + self.eps = eps + self.eval_spatial_size = eval_spatial_size + + self.feat_strides = list(feat_strides) + for i in range(num_levels - len(feat_strides)): + self.feat_strides.append(feat_strides[-1] * 2 ** (i + 1)) + + # input projection (expects pre-fused weights) + self.input_proj = nn.ModuleList() + for ch in feat_channels: + if ch == hidden_dim: + self.input_proj.append(nn.Identity()) + else: + self.input_proj.append(nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d(ch, hidden_dim, 1, bias=True))]))) + in_ch = feat_channels[-1] + for i in range(num_levels - len(feat_channels)): + self.input_proj.append(nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d(in_ch if i == 0 else hidden_dim, + hidden_dim, 3, 2, 1, bias=True))]))) + in_ch = hidden_dim + + # FDR parameters (non-trainable placeholders, set from config) + self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False) + self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False) + + pts = num_points if isinstance(num_points, (list, tuple)) else [num_points] * num_levels + self.decoder = TransformerDecoder(hidden_dim, nhead, dim_feedforward, num_levels, pts, + num_layers, reg_max, self.reg_scale, self.up, eval_idx) + + self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2) + self.enc_output = nn.Sequential(OrderedDict([ + ('proj', nn.Linear(hidden_dim, hidden_dim)), + ('norm', nn.LayerNorm(hidden_dim))])) + self.enc_score_head = nn.Linear(hidden_dim, num_classes) + self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) + + self.eval_idx_ = eval_idx if eval_idx >= 0 else num_layers + eval_idx + self.dec_score_head = nn.ModuleList( + [nn.Linear(hidden_dim, num_classes) for _ in range(self.eval_idx_ + 1)]) + self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) + self.dec_bbox_head = nn.ModuleList( + [MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3) + for _ in range(self.eval_idx_ + 1)]) + self.integral = Integral(reg_max) + + if eval_spatial_size: + # Register as buffers so checkpoint values override the freshly-computed defaults + anchors, valid_mask = self._gen_anchors() + self.register_buffer('anchors', anchors) + self.register_buffer('valid_mask', valid_mask) + + def _gen_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device='cpu'): + if spatial_shapes is None: + h0, w0 = self.eval_spatial_size + spatial_shapes = [[int(h0 / s), int(w0 / s)] for s in self.feat_strides] + anchors = [] + for lvl, (h, w) in enumerate(spatial_shapes): + gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij') + gxy = (torch.stack([gx, gy], -1).float() + 0.5) / torch.tensor([w, h], dtype=dtype) + wh = torch.ones_like(gxy) * grid_size * (2. ** lvl) + anchors.append(torch.cat([gxy, wh], -1).reshape(-1, h * w, 4)) + anchors = torch.cat(anchors, 1).to(device) + valid_mask = ((anchors > self.eps) & (anchors < 1 - self.eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.full_like(anchors, float('inf'))) + return anchors, valid_mask + + def _encoder_input(self, feats: List[torch.Tensor]): + proj = [self.input_proj[i](f) for i, f in enumerate(feats)] + for i in range(len(feats), self.num_levels): + proj.append(self.input_proj[i](feats[-1] if i == len(feats) else proj[-1])) + flat, shapes = [], [] + for f in proj: + _, _, h, w = f.shape + flat.append(f.flatten(2).permute(0, 2, 1)) + shapes.append([h, w]) + return torch.cat(flat, 1), shapes + + def _decoder_input(self, memory: torch.Tensor, spatial_shapes): + anchors, valid_mask = self.anchors, self.valid_mask + if memory.shape[0] > 1: + anchors = anchors.repeat(memory.shape[0], 1, 1) + + mem = valid_mask.to(memory.dtype) * memory + out_mem = self.enc_output(mem) + logits = self.enc_score_head(out_mem) + _, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1) + idx_e = idx.unsqueeze(-1) + topk_mem = out_mem.gather(1, idx_e.expand(-1, -1, out_mem.shape[-1])) + topk_anc = anchors.gather(1, idx_e.expand(-1, -1, anchors.shape[-1])) + topk_ref = self.enc_bbox_head(topk_mem) + topk_anc + return topk_mem.detach(), topk_ref.detach() + + def forward(self, feats: List[torch.Tensor]): + memory, shapes = self._encoder_input(feats) + content, ref = self._decoder_input(memory, shapes) + out_bboxes, out_logits = self.decoder( + content, ref, memory, shapes, + self.dec_bbox_head, self.dec_score_head, + self.query_pos_head, self.pre_bbox_head, self.integral) + return {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]} + + +# --------------------------------------------------------------------------- +# Main model +# --------------------------------------------------------------------------- + +class RTv4(nn.Module): + def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, num_heads=8, feat_strides=[8, 16, 32], num_head_channels=None, + image_model=None, device=None, dtype=None, operations=None): + super().__init__() + self.dtype = dtype + self.operations = operations + + self.backbone = HGNetv2() + self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff) + self.decoder = DFINETransformer(num_classes=num_classes, hidden_dim=dec_h, num_queries=num_queries, + feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff) + + self.num_classes = num_classes + self.num_queries = num_queries + + def forward(self, x: torch.Tensor): + return self.decoder(self.encoder(self.backbone(x))) + + def postprocess(self, outputs, orig_target_sizes: torch.Tensor): + logits = outputs['pred_logits'] + boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy') + boxes = boxes * orig_target_sizes.repeat(1, 2).unsqueeze(1) + scores = F.sigmoid(logits) + scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1) + labels = idx % self.num_classes + boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4)) + return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)] diff --git a/comfy/model_base.py b/comfy/model_base.py index 2f49578f6..7f6de3a3a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -51,6 +51,7 @@ import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 +import comfy.ldm.rf_detr.rfdetr_v4 import comfy.model_management import comfy.patcher_extension @@ -1844,3 +1845,7 @@ class Kandinsky5Image(Kandinsky5): def concat_cond(self, **kwargs): return None + +class RF_DETR_v4(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rf_detr.rfdetr_v4.RTv4) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 30ea03e8e..e0602e445 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -668,6 +668,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["audio_model"] = "ace1.5" return dit_config + if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RF-DETR_v4 + dit_config = {} + dit_config["image_model"] = "rf_detr_v4" + dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0] + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c28be1716..d876756e2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1666,7 +1666,17 @@ class ACEStep15(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) +class RF_DETR_v4(supported_models_base.BASE): + unet_config = { + "image_model": "rf_detr_v4", + } -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] + supported_inference_dtypes = [torch.float16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.RF_DETR_v4(self, device=device) + return out + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RF_DETR_v4] models += [SVD_img2vid] diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 025727071..189d7d9bc 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1224,9 +1224,10 @@ class BoundingBox(ComfyTypeIO): class Input(WidgetInput): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, - socketless: bool=True, default: dict=None, component: str=None): + socketless: bool=True, default: dict=None, component: str=None, force_input: bool=None): super().__init__(id, display_name, optional, tooltip, None, default, socketless) self.component = component + self.force_input = force_input if default is None: self.default = {"x": 0, "y": 0, "width": 512, "height": 512} @@ -1234,6 +1235,8 @@ class BoundingBox(ComfyTypeIO): d = super().as_dict() if self.component: d["component"] = self.component + if self.force_input is not None: + d["forceInput"] = self.force_input return d diff --git a/comfy_extras/nodes_rfdetr.py b/comfy_extras/nodes_rfdetr.py new file mode 100644 index 000000000..321899bdd --- /dev/null +++ b/comfy_extras/nodes_rfdetr.py @@ -0,0 +1,154 @@ +from typing_extensions import override + +import torch +from comfy.ldm.rf_detr.rfdetr_v4 import COCO_CLASSES +import comfy.model_management +import comfy.utils +from comfy_api.latest import ComfyExtension, io +from torchvision.transforms import ToPILImage, ToTensor + + +class RFDETR_detect(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="RFDETR_detect", + display_name="RF-DETR Detect", + category="detection/", + inputs=[ + io.Model.Input("model", display_name="model"), + io.Image.Input("image", display_name="image"), + io.Float.Input("threshold", display_name="threshold", default=0.5), + io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all") + ], + outputs=[ + io.BoundingBox.Output("bbox")], + ) + + @classmethod + def execute(cls, model, image, threshold, class_name) -> io.NodeOutput: + B, H, W, C = image.shape + + device = comfy.model_management.get_torch_device() + orig_size = torch.tensor([[W, H]], device=device, dtype=torch.float32).expand(B, -1) # [B, 2] as (W, H) + + image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled") + + comfy.model_management.load_model_gpu(model) + out = model.model.diffusion_model(image_in.to(device=device)) # [B, num_queries, 4+num_classes] + results = model.model.diffusion_model.postprocess(out, orig_size) # list of B dicts + + all_bbox_dicts = [] + + def _postprocess(results, threshold=0.5): + det = results[0] + keep = det['scores'] > threshold + return det['boxes'][keep].cpu(), det['labels'][keep].cpu(), det['scores'][keep].cpu() + + for i in range(B): + boxes, labels, scores = _postprocess(results[i:i+1], threshold=threshold) + + print(f'\nImage {i + 1}/{B}: Detected {len(boxes)} objects (threshold={threshold}):') + for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: -x[2].item()): + print(f' {COCO_CLASSES[label.item()]:20s} {score:.3f} ' + f'[{box[0]:.0f},{box[1]:.0f},{box[2]:.0f},{box[3]:.0f}]') + + bbox_dicts = [ + { + "x": float(box[0]), + "y": float(box[1]), + "width": float(box[2] - box[0]), + "height": float(box[3] - box[1]), + "label": COCO_CLASSES[int(label)], + "score": float(score) + } + for box, label, score in zip(boxes, labels, scores) + if class_name == "all" or COCO_CLASSES[int(label)] == class_name + ] + all_bbox_dicts.append(bbox_dicts) + + return io.NodeOutput(all_bbox_dicts) + + +class RFDETR_draw(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="RFDETR_draw", + display_name="RF-DETR Draw Detections", + category="detection/", + inputs=[ + io.Image.Input("image", display_name="image", optional=True), + io.BoundingBox.Input("bbox", display_name="bbox", force_input=True), + ], + outputs=[ + io.Image.Output("out_image", display_name="out_image"), + ], + ) + + @classmethod + def execute(cls, bbox, image=None) -> io.NodeOutput: + # Normalise bbox to a list-of-lists (one list of detections per image). + # It may arrive as: a bare dict, a flat list of dicts, or a list of lists. + B = image.shape[0] if image is not None else 1 + if isinstance(bbox, dict): + bbox = [[bbox]] * B + elif not isinstance(bbox, list) or len(bbox) == 0: + bbox = [[]] * B + elif not isinstance(bbox[0], list): + # flat list of dicts → same detections for every image + bbox = [bbox] * B + + if image is None: + image = torch.zeros((B, 3, 640, 640), dtype=torch.uint8) + + all_out_images = [] + for i in range(B): + detections = bbox[i] + if detections: + boxes = torch.tensor([[d["x"], d["y"], d["x"] + d["width"], d["y"] + d["height"]] for d in detections]) + labels = torch.tensor([COCO_CLASSES.index(lbl) if (lbl := d.get("label")) in COCO_CLASSES else 0 for d in detections]) + scores = torch.tensor([d.get("score", 1.0) for d in detections]) + else: + boxes = torch.zeros((0, 4)) + labels = torch.zeros((0,), dtype=torch.long) + scores = torch.zeros((0,)) + + pil_image = image[i].movedim(-1, 0) + img = ToPILImage()(pil_image) + out_image_pil = cls.draw_detections(img, boxes, labels, scores) + all_out_images.append(ToTensor()(out_image_pil).unsqueeze(0).movedim(1, -1)) + + out_images = torch.cat(all_out_images, dim=0) + return io.NodeOutput(out_images) + + @classmethod + def draw_detections(cls, img, boxes, labels, scores): + from PIL import ImageDraw, ImageFont + draw = ImageDraw.Draw(img) + try: + font = ImageFont.truetype('arial.ttf', 16) + except Exception: + font = ImageFont.load_default() + colors = [(255,0,0),(0,200,0),(0,0,255),(255,165,0),(128,0,128), + (0,255,255),(255,20,147),(100,149,237)] + for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: x[2].item()): + x1, y1, x2, y2 = box.tolist() + c = colors[label.item() % len(colors)] + draw.rectangle([x1, y1, x2, y2], outline=c, width=3) + draw.text((x1 + 2, y1 + 2), + f'{COCO_CLASSES[label.item()]} {score:.2f}', fill=c, font=font) + return img + + +class RFDETRExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + RFDETR_detect, + RFDETR_draw, + ] + + +async def comfy_entrypoint() -> RFDETRExtension: + return RFDETRExtension() diff --git a/nodes.py b/nodes.py index e2fc20d53..18972e934 100644 --- a/nodes.py +++ b/nodes.py @@ -2448,6 +2448,7 @@ async def init_builtin_extra_nodes(): "nodes_toolkit.py", "nodes_replacements.py", "nodes_nag.py", + "nodes_rfdetr.py" ] import_failed = [] From 39291724088f0f1fb5b43243498d1e379559ede1 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 27 Feb 2026 18:00:34 +0200 Subject: [PATCH 02/11] Use comfy ops --- comfy/ldm/rf_detr/rfdetr_v4.py | 198 +++++++++++++++++---------------- comfy_extras/nodes_rfdetr.py | 10 +- 2 files changed, 107 insertions(+), 101 deletions(-) diff --git a/comfy/ldm/rf_detr/rfdetr_v4.py b/comfy/ldm/rf_detr/rfdetr_v4.py index fc76e7852..4793a20ac 100644 --- a/comfy/ldm/rf_detr/rfdetr_v4.py +++ b/comfy/ldm/rf_detr/rfdetr_v4.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import torchvision +import comfy.model_management COCO_CLASSES = [ 'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat', @@ -25,35 +26,35 @@ COCO_CLASSES = [ class ConvBNAct(nn.Module): """Conv→BN→ReLU. padding='same' adds asymmetric zero-pad (stem).""" - def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True): + def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True, device=None, dtype=None, operations=None): super().__init__() - self.conv = nn.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False) - self.bn = nn.BatchNorm2d(oc) + self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype) + self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype) self.act = nn.ReLU() if use_act else nn.Identity() def forward(self, x): return self.act(self.bn(self.conv(x))) class LightConvBNAct(nn.Module): - def __init__(self, ic, oc, k): + def __init__(self, ic, oc, k, device=None, dtype=None, operations=None): super().__init__() - self.conv1 = ConvBNAct(ic, oc, 1, use_act=False) - self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True) + self.conv1 = ConvBNAct(ic, oc, 1, use_act=False, device=device, dtype=dtype, operations=operations) + self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True, device=device, dtype=dtype, operations=operations) def forward(self, x): return self.conv2(self.conv1(x)) class _StemBlock(nn.Module): - def __init__(self, ic, mc, oc): + def __init__(self, ic, mc, oc, device=None, dtype=None, operations=None): super().__init__() - self.stem1 = ConvBNAct(ic, mc, 3, 2) + self.stem1 = ConvBNAct(ic, mc, 3, 2, device=device, dtype=dtype, operations=operations) # stem2a/stem2b use kernel=2, stride=1, no internal padding; # padding is applied manually in forward (matching PaddlePaddle original) - self.stem2a = ConvBNAct(mc, mc//2, 2, 1) - self.stem2b = ConvBNAct(mc//2, mc, 2, 1) - self.stem3 = ConvBNAct(mc*2, mc, 3, 2) - self.stem4 = ConvBNAct(mc, oc, 1) + self.stem2a = ConvBNAct(mc, mc//2, 2, 1, device=device, dtype=dtype, operations=operations) + self.stem2b = ConvBNAct(mc//2, mc, 2, 1, device=device, dtype=dtype, operations=operations) + self.stem3 = ConvBNAct(mc*2, mc, 3, 2, device=device, dtype=dtype, operations=operations) + self.stem4 = ConvBNAct(mc, oc, 1, device=device, dtype=dtype, operations=operations) self.pool = nn.MaxPool2d(2, 1, ceil_mode=True) def forward(self, x): @@ -67,20 +68,20 @@ class _StemBlock(nn.Module): class _HG_Block(nn.Module): - def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False): + def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False, device=None, dtype=None, operations=None): super().__init__() self.residual = residual if light: self.layers = nn.ModuleList( - [LightConvBNAct(ic if i == 0 else mc, mc, k) for i in range(layer_num)]) + [LightConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)]) else: self.layers = nn.ModuleList( - [ConvBNAct(ic if i == 0 else mc, mc, k) for i in range(layer_num)]) + [ConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)]) total = ic + layer_num * mc self.aggregation = nn.Sequential( - ConvBNAct(total, oc // 2, 1), - ConvBNAct(oc // 2, oc, 1)) + ConvBNAct(total, oc // 2, 1, device=device, dtype=dtype, operations=operations), + ConvBNAct(oc // 2, oc, 1, device=device, dtype=dtype, operations=operations)) def forward(self, x): identity = x @@ -94,15 +95,15 @@ class _HG_Block(nn.Module): class _HG_Stage(nn.Module): # config order: ic, mc, oc, num_blocks, downsample, light, k, layer_num - def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6): + def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6, device=None, dtype=None, operations=None): super().__init__() if downsample: - self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False) + self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False, device=device, dtype=dtype, operations=operations) else: self.downsample = nn.Identity() self.blocks = nn.Sequential(*[ _HG_Block(ic if i == 0 else oc, mc, oc, layer_num, - k=k, residual=(i != 0), light=light) + k=k, residual=(i != 0), light=light, device=device, dtype=dtype, operations=operations) for i in range(num_blocks) ]) @@ -117,10 +118,10 @@ class HGNetv2(nn.Module): [512, 256, 1024, 5, True, True, 5, 6], [1024,512, 2048, 2, True, True, 5, 6]] - def __init__(self, return_idx=(1, 2, 3)): + def __init__(self, return_idx=(1, 2, 3), device=None, dtype=None, operations=None): super().__init__() - self.stem = _StemBlock(3, 32, 64) - self.stages = nn.ModuleList([_HG_Stage(*cfg) for cfg in self._STAGE_CFGS]) + self.stem = _StemBlock(3, 32, 64, device=device, dtype=dtype, operations=operations) + self.stages = nn.ModuleList([_HG_Stage(*cfg, device=device, dtype=dtype, operations=operations) for cfg in self._STAGE_CFGS]) self.return_idx = list(return_idx) self.out_channels = [self._STAGE_CFGS[i][2] for i in return_idx] @@ -140,10 +141,10 @@ class HGNetv2(nn.Module): class ConvNormLayer(nn.Module): """Conv→act (expects pre-fused BN weights).""" - def __init__(self, ic, oc, k, s, g=1, padding=None, act=None): + def __init__(self, ic, oc, k, s, g=1, padding=None, act=None, device=None, dtype=None, operations=None): super().__init__() p = (k - 1) // 2 if padding is None else padding - self.conv = nn.Conv2d(ic, oc, k, s, p, groups=g, bias=True) + self.conv = operations.Conv2d(ic, oc, k, s, p, groups=g, bias=True, device=device, dtype=dtype) self.act = nn.SiLU() if act == 'silu' else nn.Identity() def forward(self, x): @@ -152,9 +153,9 @@ class ConvNormLayer(nn.Module): class VGGBlock(nn.Module): """Rep-VGG block (expects pre-fused weights).""" - def __init__(self, ic, oc): + def __init__(self, ic, oc, device=None, dtype=None, operations=None): super().__init__() - self.conv = nn.Conv2d(ic, oc, 3, 1, padding=1, bias=True) + self.conv = operations.Conv2d(ic, oc, 3, 1, padding=1, bias=True, device=device, dtype=dtype) self.act = nn.SiLU() def forward(self, x): @@ -162,13 +163,13 @@ class VGGBlock(nn.Module): class CSPLayer(nn.Module): - def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu'): + def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu', device=None, dtype=None, operations=None): super().__init__() h = int(oc * expansion) - self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act) - self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act) - self.bottlenecks = nn.Sequential(*[VGGBlock(h, h) for _ in range(num_blocks)]) - self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act) if h != oc else nn.Identity() + self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations) + self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations) + self.bottlenecks = nn.Sequential(*[VGGBlock(h, h, device=device, dtype=dtype, operations=operations) for _ in range(num_blocks)]) + self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act, device=device, dtype=dtype, operations=operations) if h != oc else nn.Identity() def forward(self, x): return self.conv3(self.bottlenecks(self.conv1(x)) + self.conv2(x)) @@ -176,13 +177,13 @@ class CSPLayer(nn.Module): class RepNCSPELAN4(nn.Module): """CSP-ELAN block — the FPN/PAN block in RTv4's HybridEncoder.""" - def __init__(self, c1, c2, c3, c4, n=3, act='silu'): + def __init__(self, c1, c2, c3, c4, n=3, act='silu', device=None, dtype=None, operations=None): super().__init__() self.c = c3 // 2 - self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act) - self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act), ConvNormLayer(c4, c4, 3, 1, act=act)) - self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act), ConvNormLayer(c4, c4, 3, 1, act=act)) - self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act) + self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act, device=device, dtype=dtype, operations=operations) + self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations)) + self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations)) + self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act, device=device, dtype=dtype, operations=operations) def forward(self, x): y = list(self.cv1(x).split((self.c, self.c), 1)) @@ -192,10 +193,10 @@ class RepNCSPELAN4(nn.Module): class SCDown(nn.Module): """Separable conv downsampling used in HybridEncoder PAN bottom-up path.""" - def __init__(self, ic, oc, k, s): + def __init__(self, ic, oc, k, s, device=None, dtype=None, operations=None): super().__init__() - self.cv1 = ConvNormLayer(ic, oc, 1, 1) - self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc) + self.cv1 = ConvNormLayer(ic, oc, 1, 1, device=device, dtype=dtype, operations=operations) + self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc, device=device, dtype=dtype, operations=operations) def forward(self, x): return self.cv2(self.cv1(x)) @@ -203,13 +204,13 @@ class SCDown(nn.Module): class _TransformerEncoderLayer(nn.Module): """Single AIFI encoder layer (pre- or post-norm, GELU by default).""" - def __init__(self, d_model, nhead, dim_feedforward): + def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.linear2 = nn.Linear(dim_feedforward, d_model) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) + self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype) + self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) self.activation = nn.GELU() def forward(self, src, src_mask=None, pos_embed=None): @@ -222,10 +223,10 @@ class _TransformerEncoderLayer(nn.Module): class _TransformerEncoder(nn.Module): """Thin wrapper so state-dict keys are encoder.0.layers.N.*""" - def __init__(self, num_layers, d_model, nhead, dim_feedforward): + def __init__(self, num_layers, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None): super().__init__() self.layers = nn.ModuleList([ - _TransformerEncoderLayer(d_model, nhead, dim_feedforward) + _TransformerEncoderLayer(d_model, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations) for _ in range(num_layers) ]) @@ -237,7 +238,7 @@ class _TransformerEncoder(nn.Module): class HybridEncoder(nn.Module): def __init__(self, in_channels=(512, 1024, 2048), feat_strides=(8, 16, 32), hidden_dim=256, nhead=8, dim_feedforward=2048, use_encoder_idx=(2,), num_encoder_layers=1, - pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640)): + pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640), device=None, dtype=None, operations=None): super().__init__() self.in_channels = list(in_channels) self.feat_strides = list(feat_strides) @@ -250,13 +251,13 @@ class HybridEncoder(nn.Module): # channel projection (expects pre-fused weights) self.input_proj = nn.ModuleList([ - nn.Sequential(OrderedDict([('conv', nn.Conv2d(ch, hidden_dim, 1, bias=True))])) + nn.Sequential(OrderedDict([('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))])) for ch in in_channels ]) # AIFI transformer — use _TransformerEncoder so keys are encoder.0.layers.N.* self.encoder = nn.ModuleList([ - _TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward) + _TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations) for _ in range(len(use_encoder_idx)) ]) @@ -265,18 +266,18 @@ class HybridEncoder(nn.Module): # top-down FPN (dfine: lateral conv has no act) self.lateral_convs = nn.ModuleList( - [ConvNormLayer(hidden_dim, hidden_dim, 1, 1) + [ConvNormLayer(hidden_dim, hidden_dim, 1, 1, device=device, dtype=dtype, operations=operations) for _ in range(len(in_channels) - 1)]) self.fpn_blocks = nn.ModuleList( - [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act) + [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations) for _ in range(len(in_channels) - 1)]) # bottom-up PAN (dfine: nn.Sequential(SCDown) — keeps checkpoint key .0.cv1/.0.cv2) self.downsample_convs = nn.ModuleList( - [nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2)) + [nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, device=device, dtype=dtype, operations=operations)) for _ in range(len(in_channels) - 1)]) self.pan_blocks = nn.ModuleList( - [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act) + [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations) for _ in range(len(in_channels) - 1)]) # cache positional embeddings for fixed spatial size @@ -360,7 +361,7 @@ def _deformable_attn_v2(value: list, spatial_shapes, sampling_locations: torch.T class MSDeformableAttention(nn.Module): - def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5): + def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5, device=None, dtype=None, operations=None): super().__init__() self.embed_dim, self.num_heads = embed_dim, num_heads self.head_dim = embed_dim // num_heads @@ -369,8 +370,8 @@ class MSDeformableAttention(nn.Module): self.offset_scale = offset_scale total = num_heads * sum(pts) self.register_buffer('num_points_scale', torch.tensor([1. / n for n in pts for _ in range(n)], dtype=torch.float32)) - self.sampling_offsets = nn.Linear(embed_dim, total * 2) - self.attention_weights = nn.Linear(embed_dim, total) + self.sampling_offsets = operations.Linear(embed_dim, total * 2, device=device, dtype=dtype) + self.attention_weights = operations.Linear(embed_dim, total, device=device, dtype=dtype) def forward(self, query, ref_pts, value, spatial_shapes): bs, Lq = query.shape[:2] @@ -386,10 +387,10 @@ class MSDeformableAttention(nn.Module): class Gate(nn.Module): - def __init__(self, d_model): + def __init__(self, d_model, device=None, dtype=None, operations=None): super().__init__() - self.gate = nn.Linear(2 * d_model, 2 * d_model) - self.norm = nn.LayerNorm(d_model) + self.gate = operations.Linear(2 * d_model, 2 * d_model, device=device, dtype=dtype) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) def forward(self, x1, x2): g1, g2 = torch.sigmoid(self.gate(torch.cat([x1, x2], -1))).chunk(2, -1) @@ -397,10 +398,10 @@ class Gate(nn.Module): class MLP(nn.Module): - def __init__(self, in_dim, hidden_dim, out_dim, num_layers): + def __init__(self, in_dim, hidden_dim, out_dim, num_layers, device=None, dtype=None, operations=None): super().__init__() dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim] - self.layers = nn.ModuleList(nn.Linear(dims[i], dims[i + 1]) for i in range(num_layers)) + self.layers = nn.ModuleList(operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)) def forward(self, x): for i, layer in enumerate(self.layers): @@ -409,16 +410,16 @@ class MLP(nn.Module): class TransformerDecoderLayer(nn.Module): - def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4): + def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) - self.norm1 = nn.LayerNorm(d_model) - self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points) - self.gateway = Gate(d_model) - self.linear1 = nn.Linear(d_model, dim_feedforward) + self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations) + self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations) + self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype) self.activation = nn.ReLU() - self.linear2 = nn.Linear(dim_feedforward, d_model) - self.norm3 = nn.LayerNorm(d_model) + self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None): q = k = target if query_pos is None else target + query_pos @@ -474,10 +475,10 @@ class Integral(nn.Module): class LQE(nn.Module): """Location Quality Estimator — refines class scores using corner distribution.""" - def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32): + def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32, device=None, dtype=None, operations=None): super().__init__() self.k, self.reg_max = k, reg_max - self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers) + self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers, device=device, dtype=dtype, operations=operations) def forward(self, scores, pred_corners): B, L, _ = pred_corners.shape @@ -488,8 +489,7 @@ class LQE(nn.Module): class TransformerDecoder(nn.Module): - def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, - num_layers, reg_max, reg_scale, up, eval_idx=-1): + def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, num_layers, reg_max, reg_scale, up, eval_idx=-1, device=None, dtype=None, operations=None): super().__init__() self.hidden_dim = hidden_dim self.num_layers = num_layers @@ -497,10 +497,10 @@ class TransformerDecoder(nn.Module): self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max self.layers = nn.ModuleList([ - TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points) + TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1) ]) - self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max) for _ in range(self.eval_idx + 1)]) + self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1)]) self.register_buffer('project', weighting_function(reg_max, up, reg_scale)) def _value_op(self, memory, spatial_shapes): @@ -557,7 +557,8 @@ class TransformerDecoder(nn.Module): class DFINETransformer(nn.Module): def __init__(self, num_classes=80, hidden_dim=256, num_queries=300, feat_channels=[256, 256, 256], feat_strides=[8, 16, 32], - num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32, reg_scale=8.0, eval_spatial_size=(640, 640)): + num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32, + reg_scale=8.0, eval_spatial_size=(640, 640), device=None, dtype=None, operations=None): super().__init__() assert len(feat_strides) == len(feat_channels) self.hidden_dim = hidden_dim @@ -577,12 +578,12 @@ class DFINETransformer(nn.Module): self.input_proj.append(nn.Identity()) else: self.input_proj.append(nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d(ch, hidden_dim, 1, bias=True))]))) + ('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))]))) in_ch = feat_channels[-1] for i in range(num_levels - len(feat_channels)): self.input_proj.append(nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d(in_ch if i == 0 else hidden_dim, - hidden_dim, 3, 2, 1, bias=True))]))) + ('conv', operations.Conv2d(in_ch if i == 0 else hidden_dim, + hidden_dim, 3, 2, 1, bias=True, device=device, dtype=dtype))]))) in_ch = hidden_dim # FDR parameters (non-trainable placeholders, set from config) @@ -591,21 +592,21 @@ class DFINETransformer(nn.Module): pts = num_points if isinstance(num_points, (list, tuple)) else [num_points] * num_levels self.decoder = TransformerDecoder(hidden_dim, nhead, dim_feedforward, num_levels, pts, - num_layers, reg_max, self.reg_scale, self.up, eval_idx) + num_layers, reg_max, self.reg_scale, self.up, eval_idx, device=device, dtype=dtype, operations=operations) - self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2) + self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2, device=device, dtype=dtype, operations=operations) self.enc_output = nn.Sequential(OrderedDict([ - ('proj', nn.Linear(hidden_dim, hidden_dim)), - ('norm', nn.LayerNorm(hidden_dim))])) - self.enc_score_head = nn.Linear(hidden_dim, num_classes) - self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) + ('proj', operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)), + ('norm', operations.LayerNorm(hidden_dim, device=device, dtype=dtype))])) + self.enc_score_head = operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) + self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations) self.eval_idx_ = eval_idx if eval_idx >= 0 else num_layers + eval_idx self.dec_score_head = nn.ModuleList( - [nn.Linear(hidden_dim, num_classes) for _ in range(self.eval_idx_ + 1)]) - self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) + [operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) for _ in range(self.eval_idx_ + 1)]) + self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations) self.dec_bbox_head = nn.ModuleList( - [MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3) + [MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx_ + 1)]) self.integral = Integral(reg_max) @@ -672,21 +673,21 @@ class DFINETransformer(nn.Module): # --------------------------------------------------------------------------- class RTv4(nn.Module): - def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, num_heads=8, feat_strides=[8, 16, 32], num_head_channels=None, - image_model=None, device=None, dtype=None, operations=None): + def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, feat_strides=[8, 16, 32], device=None, dtype=None, operations=None, **kwargs): super().__init__() + self.device = device self.dtype = dtype self.operations = operations - self.backbone = HGNetv2() - self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff) + self.backbone = HGNetv2(device=device, dtype=dtype, operations=operations) + self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff, device=device, dtype=dtype, operations=operations) self.decoder = DFINETransformer(num_classes=num_classes, hidden_dim=dec_h, num_queries=num_queries, - feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff) + feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff, device=device, dtype=dtype, operations=operations) self.num_classes = num_classes self.num_queries = num_queries - def forward(self, x: torch.Tensor): + def _forward(self, x: torch.Tensor): return self.decoder(self.encoder(self.backbone(x))) def postprocess(self, outputs, orig_target_sizes: torch.Tensor): @@ -698,3 +699,8 @@ class RTv4(nn.Module): labels = idx % self.num_classes boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4)) return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)] + + def forward(self, x: torch.Tensor, orig_target_sizes: torch.Tensor, **kwargs): + x = comfy.model_management.cast_to_device(x, self.device, self.dtype) + outputs = self._forward(x) + return self.postprocess(outputs, orig_target_sizes) diff --git a/comfy_extras/nodes_rfdetr.py b/comfy_extras/nodes_rfdetr.py index 321899bdd..513f36d5e 100644 --- a/comfy_extras/nodes_rfdetr.py +++ b/comfy_extras/nodes_rfdetr.py @@ -29,14 +29,14 @@ class RFDETR_detect(io.ComfyNode): def execute(cls, model, image, threshold, class_name) -> io.NodeOutput: B, H, W, C = image.shape - device = comfy.model_management.get_torch_device() - orig_size = torch.tensor([[W, H]], device=device, dtype=torch.float32).expand(B, -1) # [B, 2] as (W, H) - image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled") + device = comfy.model_management.get_torch_device() + dtype = model.model.get_dtype_inference() + orig_size = torch.tensor([[W, H]], device=device, dtype=dtype).expand(B, -1) # [B, 2] as (W, H) + comfy.model_management.load_model_gpu(model) - out = model.model.diffusion_model(image_in.to(device=device)) # [B, num_queries, 4+num_classes] - results = model.model.diffusion_model.postprocess(out, orig_size) # list of B dicts + results = model.model.diffusion_model(image_in.to(device=device, dtype=dtype), orig_size) # list of B dicts all_bbox_dicts = [] From 57ba8555fe3c8e73c305f810ab625371f9b3156a Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 27 Feb 2026 18:54:32 +0200 Subject: [PATCH 03/11] Use comfy attention --- comfy/ldm/rf_detr/rfdetr_v4.py | 35 ++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/rf_detr/rfdetr_v4.py b/comfy/ldm/rf_detr/rfdetr_v4.py index 4793a20ac..f44a88279 100644 --- a/comfy/ldm/rf_detr/rfdetr_v4.py +++ b/comfy/ldm/rf_detr/rfdetr_v4.py @@ -6,6 +6,7 @@ import torch.nn as nn import torch.nn.functional as F import torchvision import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention_for_device COCO_CLASSES = [ 'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat', @@ -202,11 +203,29 @@ class SCDown(nn.Module): return self.cv2(self.cv1(x)) +class SelfAttention(nn.Module): + def __init__(self, embed_dim, num_heads, device=None, dtype=None, operations=None): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.q_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + self.k_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + self.v_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + self.out_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, query, key, value, attn_mask=None): + optimized_attention = optimized_attention_for_device(query.device, False, small_input=True) + q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value) + out = optimized_attention(q, k, v, heads=self.num_heads, mask=attn_mask) + return self.out_proj(out) + + class _TransformerEncoderLayer(nn.Module): """Single AIFI encoder layer (pre- or post-norm, GELU by default).""" def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype) + self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations) self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype) self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype) self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) @@ -215,7 +234,7 @@ class _TransformerEncoderLayer(nn.Module): def forward(self, src, src_mask=None, pos_embed=None): q = k = src if pos_embed is None else src + pos_embed - src2, _ = self.self_attn(q, k, value=src, attn_mask=src_mask) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask) src = self.norm1(src + src2) src2 = self.linear2(self.activation(self.linear1(src))) return self.norm2(src + src2) @@ -307,7 +326,7 @@ class HybridEncoder(nn.Module): for i, enc_idx in enumerate(self.use_encoder_idx): h, w = proj[enc_idx].shape[2:] src = proj[enc_idx].flatten(2).permute(0, 2, 1) - pe = getattr(self, f'pos_embed{enc_idx}').to(src.device) + pe = getattr(self, f'pos_embed{enc_idx}').to(device=src.device, dtype=src.dtype) for layer in self.encoder[i].layers: src = layer(src, pos_embed=pe) proj[enc_idx] = src.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous() @@ -412,7 +431,7 @@ class MLP(nn.Module): class TransformerDecoderLayer(nn.Module): def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype) + self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations) self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations) self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations) @@ -423,7 +442,7 @@ class TransformerDecoderLayer(nn.Module): def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None): q = k = target if query_pos is None else target + query_pos - t2, _ = self.self_attn(q, k, value=target, attn_mask=attn_mask) + t2 = self.self_attn(q, k, value=target, attn_mask=attn_mask) target = self.norm1(target + t2) t2 = self.cross_attn( target if query_pos is None else target + query_pos, @@ -451,7 +470,7 @@ def weighting_function(reg_max, up, reg_scale): def distance2bbox(points, distance, reg_scale): """Decode edge-distances → cxcywh boxes.""" - rs = abs(reg_scale) + rs = abs(reg_scale).to(dtype=points.dtype) x1 = points[..., 0] - (0.5 * rs + distance[..., 0]) * (points[..., 2] / rs) y1 = points[..., 1] - (0.5 * rs + distance[..., 1]) * (points[..., 3] / rs) x2 = points[..., 0] + (0.5 * rs + distance[..., 2]) * (points[..., 2] / rs) @@ -469,7 +488,7 @@ class Integral(nn.Module): def forward(self, x, project): shape = x.shape x = F.softmax(x.reshape(-1, self.reg_max + 1), 1) - x = F.linear(x, project.to(x.device)).reshape(-1, 4) + x = F.linear(x, project.to(device=x.device, dtype=x.dtype)).reshape(-1, 4) return x.reshape(list(shape[:-1]) + [-1]) @@ -644,7 +663,7 @@ class DFINETransformer(nn.Module): return torch.cat(flat, 1), shapes def _decoder_input(self, memory: torch.Tensor, spatial_shapes): - anchors, valid_mask = self.anchors, self.valid_mask + anchors, valid_mask = self.anchors.to(memory.dtype), self.valid_mask if memory.shape[0] > 1: anchors = anchors.repeat(memory.shape[0], 1, 1) From 0c66a69c91b6bd056d0143b48ec06ccea9d4be7d Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:35:57 +0200 Subject: [PATCH 04/11] Node adjustments --- comfy_extras/nodes_rfdetr.py | 77 +++++++++++++++++++----------------- comfy_extras/nodes_sdpose.py | 17 +++++++- 2 files changed, 55 insertions(+), 39 deletions(-) diff --git a/comfy_extras/nodes_rfdetr.py b/comfy_extras/nodes_rfdetr.py index 513f36d5e..0b1e22d24 100644 --- a/comfy_extras/nodes_rfdetr.py +++ b/comfy_extras/nodes_rfdetr.py @@ -6,6 +6,7 @@ import comfy.model_management import comfy.utils from comfy_api.latest import ComfyExtension, io from torchvision.transforms import ToPILImage, ToTensor +from PIL import ImageDraw, ImageFont class RFDETR_detect(io.ComfyNode): @@ -15,18 +16,20 @@ class RFDETR_detect(io.ComfyNode): node_id="RFDETR_detect", display_name="RF-DETR Detect", category="detection/", + search_aliases=["bbox", "bounding box", "object detection", "coco"], inputs=[ io.Model.Input("model", display_name="model"), io.Image.Input("image", display_name="image"), io.Float.Input("threshold", display_name="threshold", default=0.5), - io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all") + io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all", tooltip="Filter detections by class. Set to 'all' to disable filtering."), + io.Int.Input("max_detections", display_name="max_detections", default=100, tooltip="Maximum number of detections to return per image. In order of descending confidence score."), ], outputs=[ - io.BoundingBox.Output("bbox")], + io.BoundingBox.Output("bboxes")], ) @classmethod - def execute(cls, model, image, threshold, class_name) -> io.NodeOutput: + def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput: B, H, W, C = image.shape image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled") @@ -48,12 +51,7 @@ class RFDETR_detect(io.ComfyNode): for i in range(B): boxes, labels, scores = _postprocess(results[i:i+1], threshold=threshold) - print(f'\nImage {i + 1}/{B}: Detected {len(boxes)} objects (threshold={threshold}):') - for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: -x[2].item()): - print(f' {COCO_CLASSES[label.item()]:20s} {score:.3f} ' - f'[{box[0]:.0f},{box[1]:.0f},{box[2]:.0f},{box[3]:.0f}]') - - bbox_dicts = [ + bbox_dicts = sorted([ { "x": float(box[0]), "y": float(box[1]), @@ -64,67 +62,71 @@ class RFDETR_detect(io.ComfyNode): } for box, label, score in zip(boxes, labels, scores) if class_name == "all" or COCO_CLASSES[int(label)] == class_name - ] + ], key=lambda d: d["score"], reverse=True)[:max_detections] all_bbox_dicts.append(bbox_dicts) return io.NodeOutput(all_bbox_dicts) -class RFDETR_draw(io.ComfyNode): +class DrawBBoxes(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="RFDETR_draw", - display_name="RF-DETR Draw Detections", + node_id="DrawBBoxes", + display_name="Draw BBoxes", category="detection/", + search_aliases=["bbox", "bounding box", "object detection", "rf_detr", "visualize detections", "coco"], inputs=[ - io.Image.Input("image", display_name="image", optional=True), - io.BoundingBox.Input("bbox", display_name="bbox", force_input=True), + io.Image.Input("image", optional=True), + io.BoundingBox.Input("bboxes", force_input=True), ], outputs=[ - io.Image.Output("out_image", display_name="out_image"), + io.Image.Output("out_image"), ], ) @classmethod - def execute(cls, bbox, image=None) -> io.NodeOutput: - # Normalise bbox to a list-of-lists (one list of detections per image). + def execute(cls, bboxes, image=None) -> io.NodeOutput: + # Normalise bboxes to a list-of-lists (one list of detections per image). # It may arrive as: a bare dict, a flat list of dicts, or a list of lists. B = image.shape[0] if image is not None else 1 - if isinstance(bbox, dict): - bbox = [[bbox]] * B - elif not isinstance(bbox, list) or len(bbox) == 0: - bbox = [[]] * B - elif not isinstance(bbox[0], list): - # flat list of dicts → same detections for every image - bbox = [bbox] * B + if isinstance(bboxes, dict): + bboxes = [[bboxes]] * B + elif not isinstance(bboxes, list) or len(bboxes) == 0: + bboxes = [[]] * B + elif not isinstance(bboxes[0], list): + # flat list of dicts: same detections for every image + bboxes = [bboxes] * B if image is None: - image = torch.zeros((B, 3, 640, 640), dtype=torch.uint8) + B = len(bboxes) + max_w = max((int(d["x"] + d["width"]) for frame in bboxes for d in frame), default=640) + max_h = max((int(d["y"] + d["height"]) for frame in bboxes for d in frame), default=640) + image = torch.zeros((B, max_h, max_w, 3), dtype=torch.float32) all_out_images = [] for i in range(B): - detections = bbox[i] + detections = bboxes[i] if detections: boxes = torch.tensor([[d["x"], d["y"], d["x"] + d["width"], d["y"] + d["height"]] for d in detections]) - labels = torch.tensor([COCO_CLASSES.index(lbl) if (lbl := d.get("label")) in COCO_CLASSES else 0 for d in detections]) + labels = [d.get("label") if d.get("label") in COCO_CLASSES else None for d in detections] scores = torch.tensor([d.get("score", 1.0) for d in detections]) else: boxes = torch.zeros((0, 4)) - labels = torch.zeros((0,), dtype=torch.long) + labels = [] scores = torch.zeros((0,)) pil_image = image[i].movedim(-1, 0) img = ToPILImage()(pil_image) - out_image_pil = cls.draw_detections(img, boxes, labels, scores) - all_out_images.append(ToTensor()(out_image_pil).unsqueeze(0).movedim(1, -1)) + if detections: + img = cls.draw_detections(img, boxes, labels, scores) + all_out_images.append(ToTensor()(img).unsqueeze(0).movedim(1, -1)) - out_images = torch.cat(all_out_images, dim=0) + out_images = torch.cat(all_out_images, dim=0).to(comfy.model_management.intermediate_device()) return io.NodeOutput(out_images) @classmethod def draw_detections(cls, img, boxes, labels, scores): - from PIL import ImageDraw, ImageFont draw = ImageDraw.Draw(img) try: font = ImageFont.truetype('arial.ttf', 16) @@ -134,10 +136,11 @@ class RFDETR_draw(io.ComfyNode): (0,255,255),(255,20,147),(100,149,237)] for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: x[2].item()): x1, y1, x2, y2 = box.tolist() - c = colors[label.item() % len(colors)] + color_idx = COCO_CLASSES.index(label) if label is not None else 0 + c = colors[color_idx % len(colors)] draw.rectangle([x1, y1, x2, y2], outline=c, width=3) - draw.text((x1 + 2, y1 + 2), - f'{COCO_CLASSES[label.item()]} {score:.2f}', fill=c, font=font) + if label is not None: + draw.text((x1 + 2, y1 + 2), f'{label} {score:.2f}', fill=c, font=font) return img @@ -146,7 +149,7 @@ class RFDETRExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ RFDETR_detect, - RFDETR_draw, + DrawBBoxes, ] diff --git a/comfy_extras/nodes_sdpose.py b/comfy_extras/nodes_sdpose.py index 71441848e..46b5fb226 100644 --- a/comfy_extras/nodes_sdpose.py +++ b/comfy_extras/nodes_sdpose.py @@ -661,6 +661,7 @@ class CropByBBoxes(io.ComfyNode): io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."), io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."), io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."), + io.Combo.Input("keep_aspect", options=["stretch", "pad"], default="stretch", tooltip="Whether to stretch the crop to fit the output size, or pad with black pixels to preserve aspect ratio."), ], outputs=[ io.Image.Output(tooltip="All crops stacked into a single image batch."), @@ -668,7 +669,7 @@ class CropByBBoxes(io.ComfyNode): ) @classmethod - def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput: + def execute(cls, image, bboxes, output_width, output_height, padding, keep_aspect="stretch") -> io.NodeOutput: total_frames = image.shape[0] img_h = image.shape[1] img_w = image.shape[2] @@ -716,7 +717,19 @@ class CropByBBoxes(io.ComfyNode): x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2 crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w) - resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled") + + if keep_aspect == "pad": + crop_h, crop_w = y2 - y1, x2 - x1 + scale = min(output_width / crop_w, output_height / crop_h) + scaled_w = int(round(crop_w * scale)) + scaled_h = int(round(crop_h * scale)) + scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled") + pad_left = (output_width - scaled_w) // 2 + pad_top = (output_height - scaled_h) // 2 + resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device) + resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled + else: # "stretch" + resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled") crops.append(resized) if not crops: From 630793cdb2f7567a7f2843bbadfe4a5ec0b4e2f4 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 3 Mar 2026 20:56:56 +0200 Subject: [PATCH 05/11] Fix typo --- .../rfdetr_v4.py => rt_detr/rtdetr_v4.py} | 0 comfy/model_base.py | 6 +++--- comfy/model_detection.py | 4 ++-- comfy/supported_models.py | 8 ++++---- .../{nodes_rfdetr.py => nodes_rtdetr.py} | 18 +++++++++--------- nodes.py | 2 +- 6 files changed, 19 insertions(+), 19 deletions(-) rename comfy/ldm/{rf_detr/rfdetr_v4.py => rt_detr/rtdetr_v4.py} (100%) rename comfy_extras/{nodes_rfdetr.py => nodes_rtdetr.py} (94%) diff --git a/comfy/ldm/rf_detr/rfdetr_v4.py b/comfy/ldm/rt_detr/rtdetr_v4.py similarity index 100% rename from comfy/ldm/rf_detr/rfdetr_v4.py rename to comfy/ldm/rt_detr/rtdetr_v4.py diff --git a/comfy/model_base.py b/comfy/model_base.py index fcc9eb9ab..2e990dd75 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -51,7 +51,7 @@ import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 -import comfy.ldm.rf_detr.rfdetr_v4 +import comfy.ldm.rt_detr.rtdetr_v4 import comfy.model_management import comfy.patcher_extension @@ -1920,6 +1920,6 @@ class Kandinsky5Image(Kandinsky5): def concat_cond(self, **kwargs): return None -class RF_DETR_v4(BaseModel): +class RT_DETR_v4(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rf_detr.rfdetr_v4.RTv4) + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index aec4af00b..ada3d73a1 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -675,9 +675,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["audio_model"] = "ace1.5" return dit_config - if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RF-DETR_v4 + if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RT-DETR_v4 dit_config = {} - dit_config["image_model"] = "rf_detr_v4" + dit_config["image_model"] = "RT_DETR_v4" dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0] return dit_config diff --git a/comfy/supported_models.py b/comfy/supported_models.py index e3c658339..51db3dba6 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1719,17 +1719,17 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -class RF_DETR_v4(supported_models_base.BASE): +class RT_DETR_v4(supported_models_base.BASE): unet_config = { - "image_model": "rf_detr_v4", + "image_model": "RT_DETR_v4", } supported_inference_dtypes = [torch.float16, torch.float32] def get_model(self, state_dict, prefix="", device=None): - out = model_base.RF_DETR_v4(self, device=device) + out = model_base.RT_DETR_v4(self, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RF_DETR_v4] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_rfdetr.py b/comfy_extras/nodes_rtdetr.py similarity index 94% rename from comfy_extras/nodes_rfdetr.py rename to comfy_extras/nodes_rtdetr.py index 0b1e22d24..5e78065f7 100644 --- a/comfy_extras/nodes_rfdetr.py +++ b/comfy_extras/nodes_rtdetr.py @@ -1,7 +1,7 @@ from typing_extensions import override import torch -from comfy.ldm.rf_detr.rfdetr_v4 import COCO_CLASSES +from comfy.ldm.rt_detr.rtdetr_v4 import COCO_CLASSES import comfy.model_management import comfy.utils from comfy_api.latest import ComfyExtension, io @@ -9,12 +9,12 @@ from torchvision.transforms import ToPILImage, ToTensor from PIL import ImageDraw, ImageFont -class RFDETR_detect(io.ComfyNode): +class RTDETR_detect(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="RFDETR_detect", - display_name="RF-DETR Detect", + node_id="RTDETR_detect", + display_name="RT-DETR Detect", category="detection/", search_aliases=["bbox", "bounding box", "object detection", "coco"], inputs=[ @@ -75,7 +75,7 @@ class DrawBBoxes(io.ComfyNode): node_id="DrawBBoxes", display_name="Draw BBoxes", category="detection/", - search_aliases=["bbox", "bounding box", "object detection", "rf_detr", "visualize detections", "coco"], + search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"], inputs=[ io.Image.Input("image", optional=True), io.BoundingBox.Input("bboxes", force_input=True), @@ -144,14 +144,14 @@ class DrawBBoxes(io.ComfyNode): return img -class RFDETRExtension(ComfyExtension): +class RTDETRExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - RFDETR_detect, + RTDETR_detect, DrawBBoxes, ] -async def comfy_entrypoint() -> RFDETRExtension: - return RFDETRExtension() +async def comfy_entrypoint() -> RTDETRExtension: + return RTDETRExtension() diff --git a/nodes.py b/nodes.py index aebaf2fa3..7d8372015 100644 --- a/nodes.py +++ b/nodes.py @@ -2449,7 +2449,7 @@ async def init_builtin_extra_nodes(): "nodes_replacements.py", "nodes_nag.py", "nodes_sdpose.py", - "nodes_rfdetr.py" + "nodes_rtdetr.py" ] import_failed = [] From a888b90f0638bb9bfd48428ec93a9b5375a6b210 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 4 Mar 2026 00:33:48 +0200 Subject: [PATCH 06/11] Add BatchNorm2d to ops, support dynamic vram --- comfy/ldm/rt_detr/rtdetr_v4.py | 24 ++++++++++++------------ comfy/ops.py | 32 ++++++++++++++++++++++++++++++++ comfy_extras/nodes_rtdetr.py | 25 ++++++++++--------------- 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/comfy/ldm/rt_detr/rtdetr_v4.py b/comfy/ldm/rt_detr/rtdetr_v4.py index f44a88279..3233dbdf6 100644 --- a/comfy/ldm/rt_detr/rtdetr_v4.py +++ b/comfy/ldm/rt_detr/rtdetr_v4.py @@ -31,7 +31,7 @@ class ConvBNAct(nn.Module): super().__init__() self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype) - self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype) + self.bn = operations.BatchNorm2d(oc, device=device, dtype=dtype) self.act = nn.ReLU() if use_act else nn.Identity() def forward(self, x): @@ -399,7 +399,7 @@ class MSDeformableAttention(nn.Module): attn_w = F.softmax( self.attention_weights(query).reshape( bs, Lq, self.num_heads, sum(self.num_points_list)), -1) - scale = self.num_points_scale.to(query.dtype).unsqueeze(-1) + scale = self.num_points_scale.to(query).unsqueeze(-1) offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale locs = ref_pts[:, :, None, :, :2] + offset # [bs, Lq, n_head, sum_pts, 2] return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list) @@ -662,12 +662,12 @@ class DFINETransformer(nn.Module): shapes.append([h, w]) return torch.cat(flat, 1), shapes - def _decoder_input(self, memory: torch.Tensor, spatial_shapes): - anchors, valid_mask = self.anchors.to(memory.dtype), self.valid_mask + def _decoder_input(self, memory: torch.Tensor): + anchors, valid_mask = self.anchors.to(memory), self.valid_mask if memory.shape[0] > 1: anchors = anchors.repeat(memory.shape[0], 1, 1) - mem = valid_mask.to(memory.dtype) * memory + mem = valid_mask.to(memory) * memory out_mem = self.enc_output(mem) logits = self.enc_score_head(out_mem) _, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1) @@ -679,7 +679,7 @@ class DFINETransformer(nn.Module): def forward(self, feats: List[torch.Tensor]): memory, shapes = self._encoder_input(feats) - content, ref = self._decoder_input(memory, shapes) + content, ref = self._decoder_input(memory) out_bboxes, out_logits = self.decoder( content, ref, memory, shapes, self.dec_bbox_head, self.dec_score_head, @@ -705,21 +705,21 @@ class RTv4(nn.Module): self.num_classes = num_classes self.num_queries = num_queries + self.load_device = comfy.model_management.get_torch_device() def _forward(self, x: torch.Tensor): return self.decoder(self.encoder(self.backbone(x))) - def postprocess(self, outputs, orig_target_sizes: torch.Tensor): + def postprocess(self, outputs, orig_size: tuple = (640, 640)) -> List[dict]: logits = outputs['pred_logits'] boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy') - boxes = boxes * orig_target_sizes.repeat(1, 2).unsqueeze(1) + boxes = boxes * torch.tensor(orig_size, device=boxes.device, dtype=boxes.dtype).repeat(1, 2).unsqueeze(1) scores = F.sigmoid(logits) scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1) labels = idx % self.num_classes boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4)) return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)] - def forward(self, x: torch.Tensor, orig_target_sizes: torch.Tensor, **kwargs): - x = comfy.model_management.cast_to_device(x, self.device, self.dtype) - outputs = self._forward(x) - return self.postprocess(outputs, orig_target_sizes) + def forward(self, x: torch.Tensor, orig_size: tuple = (640, 640), **kwargs): + outputs = self._forward(x.to(device=self.load_device, dtype=self.dtype)) + return self.postprocess(outputs, orig_size) diff --git a/comfy/ops.py b/comfy/ops.py index 6ee6075fb..09066d9ee 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -483,6 +483,35 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) + class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + if self.weight is not None: + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + else: + weight = None + bias = None + offload_stream = None + + x = torch.nn.functional.batch_norm( + input, + comfy.model_management.cast_to(self.running_mean, dtype=input.dtype, device=input.device), + comfy.model_management.cast_to(self.running_var, dtype=input.dtype, device=input.device), + weight, bias, self.training or not self.track_running_stats, + self.momentum, self.eps + ) + uncast_bias_weight(self, weight, bias, offload_stream) + return x + + def forward(self, *args, **kwargs): + run_every_op() + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): return None @@ -596,6 +625,9 @@ class manual_cast(disable_weight_init): class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True + class BatchNorm2d(disable_weight_init.BatchNorm2d): + comfy_cast_weights = True + def fp8_linear(self, input): """ diff --git a/comfy_extras/nodes_rtdetr.py b/comfy_extras/nodes_rtdetr.py index 5e78065f7..60c3c9b92 100644 --- a/comfy_extras/nodes_rtdetr.py +++ b/comfy_extras/nodes_rtdetr.py @@ -34,24 +34,18 @@ class RTDETR_detect(io.ComfyNode): image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled") - device = comfy.model_management.get_torch_device() - dtype = model.model.get_dtype_inference() - orig_size = torch.tensor([[W, H]], device=device, dtype=dtype).expand(B, -1) # [B, 2] as (W, H) - comfy.model_management.load_model_gpu(model) - results = model.model.diffusion_model(image_in.to(device=device, dtype=dtype), orig_size) # list of B dicts + results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts all_bbox_dicts = [] - def _postprocess(results, threshold=0.5): - det = results[0] - keep = det['scores'] > threshold - return det['boxes'][keep].cpu(), det['labels'][keep].cpu(), det['scores'][keep].cpu() + for det in results: + keep = det['scores'] > threshold + boxes = det['boxes'][keep].cpu() + labels = det['labels'][keep].cpu() + scores = det['scores'][keep].cpu() - for i in range(B): - boxes, labels, scores = _postprocess(results[i:i+1], threshold=threshold) - - bbox_dicts = sorted([ + bbox_dicts = [ { "x": float(box[0]), "y": float(box[1]), @@ -62,8 +56,9 @@ class RTDETR_detect(io.ComfyNode): } for box, label, score in zip(boxes, labels, scores) if class_name == "all" or COCO_CLASSES[int(label)] == class_name - ], key=lambda d: d["score"], reverse=True)[:max_detections] - all_bbox_dicts.append(bbox_dicts) + ] + bbox_dicts.sort(key=lambda d: d["score"], reverse=True) + all_bbox_dicts.append(bbox_dicts[:max_detections]) return io.NodeOutput(all_bbox_dicts) From 7c93167ca77d5aed5f76992040afe4b8f46f6518 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 4 Mar 2026 00:36:30 +0200 Subject: [PATCH 07/11] formatting --- comfy/supported_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 81d45203f..9bc628bc2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1702,6 +1702,7 @@ class ACEStep15(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) + class LongCatImage(supported_models_base.BASE): unet_config = { "image_model": "flux", @@ -1733,6 +1734,7 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) + class RT_DETR_v4(supported_models_base.BASE): unet_config = { "image_model": "RT_DETR_v4", From 69c83a60dbbcc837c985d34ee5aa23b606d6d156 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 4 Mar 2026 01:30:40 +0200 Subject: [PATCH 08/11] Small fixes --- comfy/supported_models.py | 3 +++ comfy_extras/nodes_rtdetr.py | 18 ++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 9bc628bc2..6606d27e4 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1746,6 +1746,9 @@ class RT_DETR_v4(supported_models_base.BASE): out = model_base.RT_DETR_v4(self, device=device) return out + def clip_target(self, state_dict={}): + return None + models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_rtdetr.py b/comfy_extras/nodes_rtdetr.py index 60c3c9b92..61307e268 100644 --- a/comfy_extras/nodes_rtdetr.py +++ b/comfy_extras/nodes_rtdetr.py @@ -82,16 +82,18 @@ class DrawBBoxes(io.ComfyNode): @classmethod def execute(cls, bboxes, image=None) -> io.NodeOutput: - # Normalise bboxes to a list-of-lists (one list of detections per image). - # It may arrive as: a bare dict, a flat list of dicts, or a list of lists. + # Normalise to list[list[dict]], then fit to batch size B. B = image.shape[0] if image is not None else 1 if isinstance(bboxes, dict): - bboxes = [[bboxes]] * B - elif not isinstance(bboxes, list) or len(bboxes) == 0: - bboxes = [[]] * B - elif not isinstance(bboxes[0], list): - # flat list of dicts: same detections for every image - bboxes = [bboxes] * B + bboxes = [[bboxes]] + elif not isinstance(bboxes, list) or not bboxes: + bboxes = [[]] + elif isinstance(bboxes[0], dict): + bboxes = [bboxes] # flat list → same detections for every image + + if len(bboxes) == 1: + bboxes = bboxes * B + bboxes = (bboxes + [[]] * B)[:B] if image is None: B = len(bboxes) From 7250d013dfac0269b16d30a31b77242a4b7b3672 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 5 Mar 2026 15:20:42 +0200 Subject: [PATCH 09/11] Not necessary for this model anymore --- comfy/ldm/rt_detr/rtdetr_v4.py | 2 +- comfy/ops.py | 31 ------------------------------- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/comfy/ldm/rt_detr/rtdetr_v4.py b/comfy/ldm/rt_detr/rtdetr_v4.py index 3233dbdf6..426d3e6c3 100644 --- a/comfy/ldm/rt_detr/rtdetr_v4.py +++ b/comfy/ldm/rt_detr/rtdetr_v4.py @@ -31,7 +31,7 @@ class ConvBNAct(nn.Module): super().__init__() self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype) - self.bn = operations.BatchNorm2d(oc, device=device, dtype=dtype) + self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype) self.act = nn.ReLU() if use_act else nn.Identity() def forward(self, x): diff --git a/comfy/ops.py b/comfy/ops.py index 130dad74e..3752ed395 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -483,34 +483,6 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp): - def reset_parameters(self): - return None - - def forward_comfy_cast_weights(self, input): - if self.weight is not None: - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - else: - weight = None - bias = None - offload_stream = None - - x = torch.nn.functional.batch_norm( - input, - comfy.model_management.cast_to(self.running_mean, dtype=input.dtype, device=input.device), - comfy.model_management.cast_to(self.running_var, dtype=input.dtype, device=input.device), - weight, bias, self.training or not self.track_running_stats, - self.momentum, self.eps - ) - uncast_bias_weight(self, weight, bias, offload_stream) - return x - - def forward(self, *args, **kwargs): - run_every_op() - if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(*args, **kwargs) - else: - return super().forward(*args, **kwargs) class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): @@ -625,9 +597,6 @@ class manual_cast(disable_weight_init): class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True - class BatchNorm2d(disable_weight_init.BatchNorm2d): - comfy_cast_weights = True - def fp8_linear(self, input): """ From 2c760cbf9820f1995cbcb47faf469fd83df32abf Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 5 Mar 2026 15:21:36 +0200 Subject: [PATCH 10/11] Update ops.py --- comfy/ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 3752ed395..3e19cd1b6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -483,7 +483,6 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): return None From 55b36bb5fa1fde41546ac1a2591e3d82d645d9ec Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 5 Mar 2026 15:56:40 +0200 Subject: [PATCH 11/11] Safer --- comfy/ldm/rt_detr/rtdetr_v4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/rt_detr/rtdetr_v4.py b/comfy/ldm/rt_detr/rtdetr_v4.py index 426d3e6c3..9443761cb 100644 --- a/comfy/ldm/rt_detr/rtdetr_v4.py +++ b/comfy/ldm/rt_detr/rtdetr_v4.py @@ -464,7 +464,7 @@ def weighting_function(reg_max, up, reg_scale): step = (ub1 + 1) ** (2 / (reg_max - 2)) left = [-(step ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)] right = [ (step ** i) - 1 for i in range(1, reg_max // 2)] - vals = [-ub2] + left + [torch.zeros_like(up[0][None])] + right + [ub2] + vals = [-ub2] + left + [0] + right + [ub2] return torch.tensor(vals, dtype=up.dtype, device=up.device)