Use comfy ops

This commit is contained in:
kijai 2026-02-27 18:00:34 +02:00
parent f18b293910
commit 3929172408
2 changed files with 107 additions and 101 deletions

View File

@ -5,6 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import comfy.model_management
COCO_CLASSES = [
'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat',
@ -25,35 +26,35 @@ COCO_CLASSES = [
class ConvBNAct(nn.Module):
"""Conv→BN→ReLU. padding='same' adds asymmetric zero-pad (stem)."""
def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True):
def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True, device=None, dtype=None, operations=None):
super().__init__()
self.conv = nn.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False)
self.bn = nn.BatchNorm2d(oc)
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
self.act = nn.ReLU() if use_act else nn.Identity()
def forward(self, x):
return self.act(self.bn(self.conv(x)))
class LightConvBNAct(nn.Module):
def __init__(self, ic, oc, k):
def __init__(self, ic, oc, k, device=None, dtype=None, operations=None):
super().__init__()
self.conv1 = ConvBNAct(ic, oc, 1, use_act=False)
self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True)
self.conv1 = ConvBNAct(ic, oc, 1, use_act=False, device=device, dtype=dtype, operations=operations)
self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True, device=device, dtype=dtype, operations=operations)
def forward(self, x):
return self.conv2(self.conv1(x))
class _StemBlock(nn.Module):
def __init__(self, ic, mc, oc):
def __init__(self, ic, mc, oc, device=None, dtype=None, operations=None):
super().__init__()
self.stem1 = ConvBNAct(ic, mc, 3, 2)
self.stem1 = ConvBNAct(ic, mc, 3, 2, device=device, dtype=dtype, operations=operations)
# stem2a/stem2b use kernel=2, stride=1, no internal padding;
# padding is applied manually in forward (matching PaddlePaddle original)
self.stem2a = ConvBNAct(mc, mc//2, 2, 1)
self.stem2b = ConvBNAct(mc//2, mc, 2, 1)
self.stem3 = ConvBNAct(mc*2, mc, 3, 2)
self.stem4 = ConvBNAct(mc, oc, 1)
self.stem2a = ConvBNAct(mc, mc//2, 2, 1, device=device, dtype=dtype, operations=operations)
self.stem2b = ConvBNAct(mc//2, mc, 2, 1, device=device, dtype=dtype, operations=operations)
self.stem3 = ConvBNAct(mc*2, mc, 3, 2, device=device, dtype=dtype, operations=operations)
self.stem4 = ConvBNAct(mc, oc, 1, device=device, dtype=dtype, operations=operations)
self.pool = nn.MaxPool2d(2, 1, ceil_mode=True)
def forward(self, x):
@ -67,20 +68,20 @@ class _StemBlock(nn.Module):
class _HG_Block(nn.Module):
def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False):
def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False, device=None, dtype=None, operations=None):
super().__init__()
self.residual = residual
if light:
self.layers = nn.ModuleList(
[LightConvBNAct(ic if i == 0 else mc, mc, k) for i in range(layer_num)])
[LightConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
else:
self.layers = nn.ModuleList(
[ConvBNAct(ic if i == 0 else mc, mc, k) for i in range(layer_num)])
[ConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
total = ic + layer_num * mc
self.aggregation = nn.Sequential(
ConvBNAct(total, oc // 2, 1),
ConvBNAct(oc // 2, oc, 1))
ConvBNAct(total, oc // 2, 1, device=device, dtype=dtype, operations=operations),
ConvBNAct(oc // 2, oc, 1, device=device, dtype=dtype, operations=operations))
def forward(self, x):
identity = x
@ -94,15 +95,15 @@ class _HG_Block(nn.Module):
class _HG_Stage(nn.Module):
# config order: ic, mc, oc, num_blocks, downsample, light, k, layer_num
def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6):
def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6, device=None, dtype=None, operations=None):
super().__init__()
if downsample:
self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False)
self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False, device=device, dtype=dtype, operations=operations)
else:
self.downsample = nn.Identity()
self.blocks = nn.Sequential(*[
_HG_Block(ic if i == 0 else oc, mc, oc, layer_num,
k=k, residual=(i != 0), light=light)
k=k, residual=(i != 0), light=light, device=device, dtype=dtype, operations=operations)
for i in range(num_blocks)
])
@ -117,10 +118,10 @@ class HGNetv2(nn.Module):
[512, 256, 1024, 5, True, True, 5, 6],
[1024,512, 2048, 2, True, True, 5, 6]]
def __init__(self, return_idx=(1, 2, 3)):
def __init__(self, return_idx=(1, 2, 3), device=None, dtype=None, operations=None):
super().__init__()
self.stem = _StemBlock(3, 32, 64)
self.stages = nn.ModuleList([_HG_Stage(*cfg) for cfg in self._STAGE_CFGS])
self.stem = _StemBlock(3, 32, 64, device=device, dtype=dtype, operations=operations)
self.stages = nn.ModuleList([_HG_Stage(*cfg, device=device, dtype=dtype, operations=operations) for cfg in self._STAGE_CFGS])
self.return_idx = list(return_idx)
self.out_channels = [self._STAGE_CFGS[i][2] for i in return_idx]
@ -140,10 +141,10 @@ class HGNetv2(nn.Module):
class ConvNormLayer(nn.Module):
"""Conv→act (expects pre-fused BN weights)."""
def __init__(self, ic, oc, k, s, g=1, padding=None, act=None):
def __init__(self, ic, oc, k, s, g=1, padding=None, act=None, device=None, dtype=None, operations=None):
super().__init__()
p = (k - 1) // 2 if padding is None else padding
self.conv = nn.Conv2d(ic, oc, k, s, p, groups=g, bias=True)
self.conv = operations.Conv2d(ic, oc, k, s, p, groups=g, bias=True, device=device, dtype=dtype)
self.act = nn.SiLU() if act == 'silu' else nn.Identity()
def forward(self, x):
@ -152,9 +153,9 @@ class ConvNormLayer(nn.Module):
class VGGBlock(nn.Module):
"""Rep-VGG block (expects pre-fused weights)."""
def __init__(self, ic, oc):
def __init__(self, ic, oc, device=None, dtype=None, operations=None):
super().__init__()
self.conv = nn.Conv2d(ic, oc, 3, 1, padding=1, bias=True)
self.conv = operations.Conv2d(ic, oc, 3, 1, padding=1, bias=True, device=device, dtype=dtype)
self.act = nn.SiLU()
def forward(self, x):
@ -162,13 +163,13 @@ class VGGBlock(nn.Module):
class CSPLayer(nn.Module):
def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu'):
def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu', device=None, dtype=None, operations=None):
super().__init__()
h = int(oc * expansion)
self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act)
self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act)
self.bottlenecks = nn.Sequential(*[VGGBlock(h, h) for _ in range(num_blocks)])
self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act) if h != oc else nn.Identity()
self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.bottlenecks = nn.Sequential(*[VGGBlock(h, h, device=device, dtype=dtype, operations=operations) for _ in range(num_blocks)])
self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act, device=device, dtype=dtype, operations=operations) if h != oc else nn.Identity()
def forward(self, x):
return self.conv3(self.bottlenecks(self.conv1(x)) + self.conv2(x))
@ -176,13 +177,13 @@ class CSPLayer(nn.Module):
class RepNCSPELAN4(nn.Module):
"""CSP-ELAN block — the FPN/PAN block in RTv4's HybridEncoder."""
def __init__(self, c1, c2, c3, c4, n=3, act='silu'):
def __init__(self, c1, c2, c3, c4, n=3, act='silu', device=None, dtype=None, operations=None):
super().__init__()
self.c = c3 // 2
self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act)
self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act), ConvNormLayer(c4, c4, 3, 1, act=act))
self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act), ConvNormLayer(c4, c4, 3, 1, act=act))
self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act)
self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
def forward(self, x):
y = list(self.cv1(x).split((self.c, self.c), 1))
@ -192,10 +193,10 @@ class RepNCSPELAN4(nn.Module):
class SCDown(nn.Module):
"""Separable conv downsampling used in HybridEncoder PAN bottom-up path."""
def __init__(self, ic, oc, k, s):
def __init__(self, ic, oc, k, s, device=None, dtype=None, operations=None):
super().__init__()
self.cv1 = ConvNormLayer(ic, oc, 1, 1)
self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc)
self.cv1 = ConvNormLayer(ic, oc, 1, 1, device=device, dtype=dtype, operations=operations)
self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc, device=device, dtype=dtype, operations=operations)
def forward(self, x):
return self.cv2(self.cv1(x))
@ -203,13 +204,13 @@ class SCDown(nn.Module):
class _TransformerEncoderLayer(nn.Module):
"""Single AIFI encoder layer (pre- or post-norm, GELU by default)."""
def __init__(self, d_model, nhead, dim_feedforward):
def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype)
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.activation = nn.GELU()
def forward(self, src, src_mask=None, pos_embed=None):
@ -222,10 +223,10 @@ class _TransformerEncoderLayer(nn.Module):
class _TransformerEncoder(nn.Module):
"""Thin wrapper so state-dict keys are encoder.0.layers.N.*"""
def __init__(self, num_layers, d_model, nhead, dim_feedforward):
def __init__(self, num_layers, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
super().__init__()
self.layers = nn.ModuleList([
_TransformerEncoderLayer(d_model, nhead, dim_feedforward)
_TransformerEncoderLayer(d_model, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
@ -237,7 +238,7 @@ class _TransformerEncoder(nn.Module):
class HybridEncoder(nn.Module):
def __init__(self, in_channels=(512, 1024, 2048), feat_strides=(8, 16, 32), hidden_dim=256, nhead=8, dim_feedforward=2048, use_encoder_idx=(2,), num_encoder_layers=1,
pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640)):
pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
super().__init__()
self.in_channels = list(in_channels)
self.feat_strides = list(feat_strides)
@ -250,13 +251,13 @@ class HybridEncoder(nn.Module):
# channel projection (expects pre-fused weights)
self.input_proj = nn.ModuleList([
nn.Sequential(OrderedDict([('conv', nn.Conv2d(ch, hidden_dim, 1, bias=True))]))
nn.Sequential(OrderedDict([('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))]))
for ch in in_channels
])
# AIFI transformer — use _TransformerEncoder so keys are encoder.0.layers.N.*
self.encoder = nn.ModuleList([
_TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward)
_TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
for _ in range(len(use_encoder_idx))
])
@ -265,18 +266,18 @@ class HybridEncoder(nn.Module):
# top-down FPN (dfine: lateral conv has no act)
self.lateral_convs = nn.ModuleList(
[ConvNormLayer(hidden_dim, hidden_dim, 1, 1)
[ConvNormLayer(hidden_dim, hidden_dim, 1, 1, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
self.fpn_blocks = nn.ModuleList(
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act)
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
# bottom-up PAN (dfine: nn.Sequential(SCDown) — keeps checkpoint key .0.cv1/.0.cv2)
self.downsample_convs = nn.ModuleList(
[nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2))
[nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, device=device, dtype=dtype, operations=operations))
for _ in range(len(in_channels) - 1)])
self.pan_blocks = nn.ModuleList(
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act)
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
# cache positional embeddings for fixed spatial size
@ -360,7 +361,7 @@ def _deformable_attn_v2(value: list, spatial_shapes, sampling_locations: torch.T
class MSDeformableAttention(nn.Module):
def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5):
def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5, device=None, dtype=None, operations=None):
super().__init__()
self.embed_dim, self.num_heads = embed_dim, num_heads
self.head_dim = embed_dim // num_heads
@ -369,8 +370,8 @@ class MSDeformableAttention(nn.Module):
self.offset_scale = offset_scale
total = num_heads * sum(pts)
self.register_buffer('num_points_scale', torch.tensor([1. / n for n in pts for _ in range(n)], dtype=torch.float32))
self.sampling_offsets = nn.Linear(embed_dim, total * 2)
self.attention_weights = nn.Linear(embed_dim, total)
self.sampling_offsets = operations.Linear(embed_dim, total * 2, device=device, dtype=dtype)
self.attention_weights = operations.Linear(embed_dim, total, device=device, dtype=dtype)
def forward(self, query, ref_pts, value, spatial_shapes):
bs, Lq = query.shape[:2]
@ -386,10 +387,10 @@ class MSDeformableAttention(nn.Module):
class Gate(nn.Module):
def __init__(self, d_model):
def __init__(self, d_model, device=None, dtype=None, operations=None):
super().__init__()
self.gate = nn.Linear(2 * d_model, 2 * d_model)
self.norm = nn.LayerNorm(d_model)
self.gate = operations.Linear(2 * d_model, 2 * d_model, device=device, dtype=dtype)
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
def forward(self, x1, x2):
g1, g2 = torch.sigmoid(self.gate(torch.cat([x1, x2], -1))).chunk(2, -1)
@ -397,10 +398,10 @@ class Gate(nn.Module):
class MLP(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
def __init__(self, in_dim, hidden_dim, out_dim, num_layers, device=None, dtype=None, operations=None):
super().__init__()
dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
self.layers = nn.ModuleList(nn.Linear(dims[i], dims[i + 1]) for i in range(num_layers))
self.layers = nn.ModuleList(operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers))
def forward(self, x):
for i, layer in enumerate(self.layers):
@ -409,16 +410,16 @@ class MLP(nn.Module):
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4):
def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points)
self.gateway = Gate(d_model)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations)
self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
self.activation = nn.ReLU()
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm3 = nn.LayerNorm(d_model)
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None):
q = k = target if query_pos is None else target + query_pos
@ -474,10 +475,10 @@ class Integral(nn.Module):
class LQE(nn.Module):
"""Location Quality Estimator — refines class scores using corner distribution."""
def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32):
def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32, device=None, dtype=None, operations=None):
super().__init__()
self.k, self.reg_max = k, reg_max
self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers)
self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers, device=device, dtype=dtype, operations=operations)
def forward(self, scores, pred_corners):
B, L, _ = pred_corners.shape
@ -488,8 +489,7 @@ class LQE(nn.Module):
class TransformerDecoder(nn.Module):
def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points,
num_layers, reg_max, reg_scale, up, eval_idx=-1):
def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, num_layers, reg_max, reg_scale, up, eval_idx=-1, device=None, dtype=None, operations=None):
super().__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
@ -497,10 +497,10 @@ class TransformerDecoder(nn.Module):
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max
self.layers = nn.ModuleList([
TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points)
TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points, device=device, dtype=dtype, operations=operations)
for _ in range(self.eval_idx + 1)
])
self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max) for _ in range(self.eval_idx + 1)])
self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1)])
self.register_buffer('project', weighting_function(reg_max, up, reg_scale))
def _value_op(self, memory, spatial_shapes):
@ -557,7 +557,8 @@ class TransformerDecoder(nn.Module):
class DFINETransformer(nn.Module):
def __init__(self, num_classes=80, hidden_dim=256, num_queries=300, feat_channels=[256, 256, 256], feat_strides=[8, 16, 32],
num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32, reg_scale=8.0, eval_spatial_size=(640, 640)):
num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32,
reg_scale=8.0, eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
super().__init__()
assert len(feat_strides) == len(feat_channels)
self.hidden_dim = hidden_dim
@ -577,12 +578,12 @@ class DFINETransformer(nn.Module):
self.input_proj.append(nn.Identity())
else:
self.input_proj.append(nn.Sequential(OrderedDict([
('conv', nn.Conv2d(ch, hidden_dim, 1, bias=True))])))
('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))])))
in_ch = feat_channels[-1]
for i in range(num_levels - len(feat_channels)):
self.input_proj.append(nn.Sequential(OrderedDict([
('conv', nn.Conv2d(in_ch if i == 0 else hidden_dim,
hidden_dim, 3, 2, 1, bias=True))])))
('conv', operations.Conv2d(in_ch if i == 0 else hidden_dim,
hidden_dim, 3, 2, 1, bias=True, device=device, dtype=dtype))])))
in_ch = hidden_dim
# FDR parameters (non-trainable placeholders, set from config)
@ -591,21 +592,21 @@ class DFINETransformer(nn.Module):
pts = num_points if isinstance(num_points, (list, tuple)) else [num_points] * num_levels
self.decoder = TransformerDecoder(hidden_dim, nhead, dim_feedforward, num_levels, pts,
num_layers, reg_max, self.reg_scale, self.up, eval_idx)
num_layers, reg_max, self.reg_scale, self.up, eval_idx, device=device, dtype=dtype, operations=operations)
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2)
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2, device=device, dtype=dtype, operations=operations)
self.enc_output = nn.Sequential(OrderedDict([
('proj', nn.Linear(hidden_dim, hidden_dim)),
('norm', nn.LayerNorm(hidden_dim))]))
self.enc_score_head = nn.Linear(hidden_dim, num_classes)
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3)
('proj', operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)),
('norm', operations.LayerNorm(hidden_dim, device=device, dtype=dtype))]))
self.enc_score_head = operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype)
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
self.eval_idx_ = eval_idx if eval_idx >= 0 else num_layers + eval_idx
self.dec_score_head = nn.ModuleList(
[nn.Linear(hidden_dim, num_classes) for _ in range(self.eval_idx_ + 1)])
self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3)
[operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) for _ in range(self.eval_idx_ + 1)])
self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
self.dec_bbox_head = nn.ModuleList(
[MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3)
[MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3, device=device, dtype=dtype, operations=operations)
for _ in range(self.eval_idx_ + 1)])
self.integral = Integral(reg_max)
@ -672,21 +673,21 @@ class DFINETransformer(nn.Module):
# ---------------------------------------------------------------------------
class RTv4(nn.Module):
def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, num_heads=8, feat_strides=[8, 16, 32], num_head_channels=None,
image_model=None, device=None, dtype=None, operations=None):
def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, feat_strides=[8, 16, 32], device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.device = device
self.dtype = dtype
self.operations = operations
self.backbone = HGNetv2()
self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff)
self.backbone = HGNetv2(device=device, dtype=dtype, operations=operations)
self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff, device=device, dtype=dtype, operations=operations)
self.decoder = DFINETransformer(num_classes=num_classes, hidden_dim=dec_h, num_queries=num_queries,
feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff)
feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff, device=device, dtype=dtype, operations=operations)
self.num_classes = num_classes
self.num_queries = num_queries
def forward(self, x: torch.Tensor):
def _forward(self, x: torch.Tensor):
return self.decoder(self.encoder(self.backbone(x)))
def postprocess(self, outputs, orig_target_sizes: torch.Tensor):
@ -698,3 +699,8 @@ class RTv4(nn.Module):
labels = idx % self.num_classes
boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4))
return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)]
def forward(self, x: torch.Tensor, orig_target_sizes: torch.Tensor, **kwargs):
x = comfy.model_management.cast_to_device(x, self.device, self.dtype)
outputs = self._forward(x)
return self.postprocess(outputs, orig_target_sizes)

View File

@ -29,14 +29,14 @@ class RFDETR_detect(io.ComfyNode):
def execute(cls, model, image, threshold, class_name) -> io.NodeOutput:
B, H, W, C = image.shape
device = comfy.model_management.get_torch_device()
orig_size = torch.tensor([[W, H]], device=device, dtype=torch.float32).expand(B, -1) # [B, 2] as (W, H)
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
device = comfy.model_management.get_torch_device()
dtype = model.model.get_dtype_inference()
orig_size = torch.tensor([[W, H]], device=device, dtype=dtype).expand(B, -1) # [B, 2] as (W, H)
comfy.model_management.load_model_gpu(model)
out = model.model.diffusion_model(image_in.to(device=device)) # [B, num_queries, 4+num_classes]
results = model.model.diffusion_model.postprocess(out, orig_size) # list of B dicts
results = model.model.diffusion_model(image_in.to(device=device, dtype=dtype), orig_size) # list of B dicts
all_bbox_dicts = []