diff --git a/comfy/ldm/rf_detr/rfdetr_v4.py b/comfy/ldm/rf_detr/rfdetr_v4.py index fc76e7852..4793a20ac 100644 --- a/comfy/ldm/rf_detr/rfdetr_v4.py +++ b/comfy/ldm/rf_detr/rfdetr_v4.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import torchvision +import comfy.model_management COCO_CLASSES = [ 'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat', @@ -25,35 +26,35 @@ COCO_CLASSES = [ class ConvBNAct(nn.Module): """Conv→BN→ReLU. padding='same' adds asymmetric zero-pad (stem).""" - def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True): + def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True, device=None, dtype=None, operations=None): super().__init__() - self.conv = nn.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False) - self.bn = nn.BatchNorm2d(oc) + self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype) + self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype) self.act = nn.ReLU() if use_act else nn.Identity() def forward(self, x): return self.act(self.bn(self.conv(x))) class LightConvBNAct(nn.Module): - def __init__(self, ic, oc, k): + def __init__(self, ic, oc, k, device=None, dtype=None, operations=None): super().__init__() - self.conv1 = ConvBNAct(ic, oc, 1, use_act=False) - self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True) + self.conv1 = ConvBNAct(ic, oc, 1, use_act=False, device=device, dtype=dtype, operations=operations) + self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True, device=device, dtype=dtype, operations=operations) def forward(self, x): return self.conv2(self.conv1(x)) class _StemBlock(nn.Module): - def __init__(self, ic, mc, oc): + def __init__(self, ic, mc, oc, device=None, dtype=None, operations=None): super().__init__() - self.stem1 = ConvBNAct(ic, mc, 3, 2) + self.stem1 = ConvBNAct(ic, mc, 3, 2, device=device, dtype=dtype, operations=operations) # stem2a/stem2b use kernel=2, stride=1, no internal padding; # padding is applied manually in forward (matching PaddlePaddle original) - self.stem2a = ConvBNAct(mc, mc//2, 2, 1) - self.stem2b = ConvBNAct(mc//2, mc, 2, 1) - self.stem3 = ConvBNAct(mc*2, mc, 3, 2) - self.stem4 = ConvBNAct(mc, oc, 1) + self.stem2a = ConvBNAct(mc, mc//2, 2, 1, device=device, dtype=dtype, operations=operations) + self.stem2b = ConvBNAct(mc//2, mc, 2, 1, device=device, dtype=dtype, operations=operations) + self.stem3 = ConvBNAct(mc*2, mc, 3, 2, device=device, dtype=dtype, operations=operations) + self.stem4 = ConvBNAct(mc, oc, 1, device=device, dtype=dtype, operations=operations) self.pool = nn.MaxPool2d(2, 1, ceil_mode=True) def forward(self, x): @@ -67,20 +68,20 @@ class _StemBlock(nn.Module): class _HG_Block(nn.Module): - def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False): + def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False, device=None, dtype=None, operations=None): super().__init__() self.residual = residual if light: self.layers = nn.ModuleList( - [LightConvBNAct(ic if i == 0 else mc, mc, k) for i in range(layer_num)]) + [LightConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)]) else: self.layers = nn.ModuleList( - [ConvBNAct(ic if i == 0 else mc, mc, k) for i in range(layer_num)]) + [ConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)]) total = ic + layer_num * mc self.aggregation = nn.Sequential( - ConvBNAct(total, oc // 2, 1), - ConvBNAct(oc // 2, oc, 1)) + ConvBNAct(total, oc // 2, 1, device=device, dtype=dtype, operations=operations), + ConvBNAct(oc // 2, oc, 1, device=device, dtype=dtype, operations=operations)) def forward(self, x): identity = x @@ -94,15 +95,15 @@ class _HG_Block(nn.Module): class _HG_Stage(nn.Module): # config order: ic, mc, oc, num_blocks, downsample, light, k, layer_num - def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6): + def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6, device=None, dtype=None, operations=None): super().__init__() if downsample: - self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False) + self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False, device=device, dtype=dtype, operations=operations) else: self.downsample = nn.Identity() self.blocks = nn.Sequential(*[ _HG_Block(ic if i == 0 else oc, mc, oc, layer_num, - k=k, residual=(i != 0), light=light) + k=k, residual=(i != 0), light=light, device=device, dtype=dtype, operations=operations) for i in range(num_blocks) ]) @@ -117,10 +118,10 @@ class HGNetv2(nn.Module): [512, 256, 1024, 5, True, True, 5, 6], [1024,512, 2048, 2, True, True, 5, 6]] - def __init__(self, return_idx=(1, 2, 3)): + def __init__(self, return_idx=(1, 2, 3), device=None, dtype=None, operations=None): super().__init__() - self.stem = _StemBlock(3, 32, 64) - self.stages = nn.ModuleList([_HG_Stage(*cfg) for cfg in self._STAGE_CFGS]) + self.stem = _StemBlock(3, 32, 64, device=device, dtype=dtype, operations=operations) + self.stages = nn.ModuleList([_HG_Stage(*cfg, device=device, dtype=dtype, operations=operations) for cfg in self._STAGE_CFGS]) self.return_idx = list(return_idx) self.out_channels = [self._STAGE_CFGS[i][2] for i in return_idx] @@ -140,10 +141,10 @@ class HGNetv2(nn.Module): class ConvNormLayer(nn.Module): """Conv→act (expects pre-fused BN weights).""" - def __init__(self, ic, oc, k, s, g=1, padding=None, act=None): + def __init__(self, ic, oc, k, s, g=1, padding=None, act=None, device=None, dtype=None, operations=None): super().__init__() p = (k - 1) // 2 if padding is None else padding - self.conv = nn.Conv2d(ic, oc, k, s, p, groups=g, bias=True) + self.conv = operations.Conv2d(ic, oc, k, s, p, groups=g, bias=True, device=device, dtype=dtype) self.act = nn.SiLU() if act == 'silu' else nn.Identity() def forward(self, x): @@ -152,9 +153,9 @@ class ConvNormLayer(nn.Module): class VGGBlock(nn.Module): """Rep-VGG block (expects pre-fused weights).""" - def __init__(self, ic, oc): + def __init__(self, ic, oc, device=None, dtype=None, operations=None): super().__init__() - self.conv = nn.Conv2d(ic, oc, 3, 1, padding=1, bias=True) + self.conv = operations.Conv2d(ic, oc, 3, 1, padding=1, bias=True, device=device, dtype=dtype) self.act = nn.SiLU() def forward(self, x): @@ -162,13 +163,13 @@ class VGGBlock(nn.Module): class CSPLayer(nn.Module): - def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu'): + def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu', device=None, dtype=None, operations=None): super().__init__() h = int(oc * expansion) - self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act) - self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act) - self.bottlenecks = nn.Sequential(*[VGGBlock(h, h) for _ in range(num_blocks)]) - self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act) if h != oc else nn.Identity() + self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations) + self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations) + self.bottlenecks = nn.Sequential(*[VGGBlock(h, h, device=device, dtype=dtype, operations=operations) for _ in range(num_blocks)]) + self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act, device=device, dtype=dtype, operations=operations) if h != oc else nn.Identity() def forward(self, x): return self.conv3(self.bottlenecks(self.conv1(x)) + self.conv2(x)) @@ -176,13 +177,13 @@ class CSPLayer(nn.Module): class RepNCSPELAN4(nn.Module): """CSP-ELAN block — the FPN/PAN block in RTv4's HybridEncoder.""" - def __init__(self, c1, c2, c3, c4, n=3, act='silu'): + def __init__(self, c1, c2, c3, c4, n=3, act='silu', device=None, dtype=None, operations=None): super().__init__() self.c = c3 // 2 - self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act) - self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act), ConvNormLayer(c4, c4, 3, 1, act=act)) - self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act), ConvNormLayer(c4, c4, 3, 1, act=act)) - self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act) + self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act, device=device, dtype=dtype, operations=operations) + self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations)) + self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations)) + self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act, device=device, dtype=dtype, operations=operations) def forward(self, x): y = list(self.cv1(x).split((self.c, self.c), 1)) @@ -192,10 +193,10 @@ class RepNCSPELAN4(nn.Module): class SCDown(nn.Module): """Separable conv downsampling used in HybridEncoder PAN bottom-up path.""" - def __init__(self, ic, oc, k, s): + def __init__(self, ic, oc, k, s, device=None, dtype=None, operations=None): super().__init__() - self.cv1 = ConvNormLayer(ic, oc, 1, 1) - self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc) + self.cv1 = ConvNormLayer(ic, oc, 1, 1, device=device, dtype=dtype, operations=operations) + self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc, device=device, dtype=dtype, operations=operations) def forward(self, x): return self.cv2(self.cv1(x)) @@ -203,13 +204,13 @@ class SCDown(nn.Module): class _TransformerEncoderLayer(nn.Module): """Single AIFI encoder layer (pre- or post-norm, GELU by default).""" - def __init__(self, d_model, nhead, dim_feedforward): + def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.linear2 = nn.Linear(dim_feedforward, d_model) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) + self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype) + self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) self.activation = nn.GELU() def forward(self, src, src_mask=None, pos_embed=None): @@ -222,10 +223,10 @@ class _TransformerEncoderLayer(nn.Module): class _TransformerEncoder(nn.Module): """Thin wrapper so state-dict keys are encoder.0.layers.N.*""" - def __init__(self, num_layers, d_model, nhead, dim_feedforward): + def __init__(self, num_layers, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None): super().__init__() self.layers = nn.ModuleList([ - _TransformerEncoderLayer(d_model, nhead, dim_feedforward) + _TransformerEncoderLayer(d_model, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations) for _ in range(num_layers) ]) @@ -237,7 +238,7 @@ class _TransformerEncoder(nn.Module): class HybridEncoder(nn.Module): def __init__(self, in_channels=(512, 1024, 2048), feat_strides=(8, 16, 32), hidden_dim=256, nhead=8, dim_feedforward=2048, use_encoder_idx=(2,), num_encoder_layers=1, - pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640)): + pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640), device=None, dtype=None, operations=None): super().__init__() self.in_channels = list(in_channels) self.feat_strides = list(feat_strides) @@ -250,13 +251,13 @@ class HybridEncoder(nn.Module): # channel projection (expects pre-fused weights) self.input_proj = nn.ModuleList([ - nn.Sequential(OrderedDict([('conv', nn.Conv2d(ch, hidden_dim, 1, bias=True))])) + nn.Sequential(OrderedDict([('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))])) for ch in in_channels ]) # AIFI transformer — use _TransformerEncoder so keys are encoder.0.layers.N.* self.encoder = nn.ModuleList([ - _TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward) + _TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations) for _ in range(len(use_encoder_idx)) ]) @@ -265,18 +266,18 @@ class HybridEncoder(nn.Module): # top-down FPN (dfine: lateral conv has no act) self.lateral_convs = nn.ModuleList( - [ConvNormLayer(hidden_dim, hidden_dim, 1, 1) + [ConvNormLayer(hidden_dim, hidden_dim, 1, 1, device=device, dtype=dtype, operations=operations) for _ in range(len(in_channels) - 1)]) self.fpn_blocks = nn.ModuleList( - [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act) + [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations) for _ in range(len(in_channels) - 1)]) # bottom-up PAN (dfine: nn.Sequential(SCDown) — keeps checkpoint key .0.cv1/.0.cv2) self.downsample_convs = nn.ModuleList( - [nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2)) + [nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, device=device, dtype=dtype, operations=operations)) for _ in range(len(in_channels) - 1)]) self.pan_blocks = nn.ModuleList( - [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act) + [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations) for _ in range(len(in_channels) - 1)]) # cache positional embeddings for fixed spatial size @@ -360,7 +361,7 @@ def _deformable_attn_v2(value: list, spatial_shapes, sampling_locations: torch.T class MSDeformableAttention(nn.Module): - def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5): + def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5, device=None, dtype=None, operations=None): super().__init__() self.embed_dim, self.num_heads = embed_dim, num_heads self.head_dim = embed_dim // num_heads @@ -369,8 +370,8 @@ class MSDeformableAttention(nn.Module): self.offset_scale = offset_scale total = num_heads * sum(pts) self.register_buffer('num_points_scale', torch.tensor([1. / n for n in pts for _ in range(n)], dtype=torch.float32)) - self.sampling_offsets = nn.Linear(embed_dim, total * 2) - self.attention_weights = nn.Linear(embed_dim, total) + self.sampling_offsets = operations.Linear(embed_dim, total * 2, device=device, dtype=dtype) + self.attention_weights = operations.Linear(embed_dim, total, device=device, dtype=dtype) def forward(self, query, ref_pts, value, spatial_shapes): bs, Lq = query.shape[:2] @@ -386,10 +387,10 @@ class MSDeformableAttention(nn.Module): class Gate(nn.Module): - def __init__(self, d_model): + def __init__(self, d_model, device=None, dtype=None, operations=None): super().__init__() - self.gate = nn.Linear(2 * d_model, 2 * d_model) - self.norm = nn.LayerNorm(d_model) + self.gate = operations.Linear(2 * d_model, 2 * d_model, device=device, dtype=dtype) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) def forward(self, x1, x2): g1, g2 = torch.sigmoid(self.gate(torch.cat([x1, x2], -1))).chunk(2, -1) @@ -397,10 +398,10 @@ class Gate(nn.Module): class MLP(nn.Module): - def __init__(self, in_dim, hidden_dim, out_dim, num_layers): + def __init__(self, in_dim, hidden_dim, out_dim, num_layers, device=None, dtype=None, operations=None): super().__init__() dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim] - self.layers = nn.ModuleList(nn.Linear(dims[i], dims[i + 1]) for i in range(num_layers)) + self.layers = nn.ModuleList(operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)) def forward(self, x): for i, layer in enumerate(self.layers): @@ -409,16 +410,16 @@ class MLP(nn.Module): class TransformerDecoderLayer(nn.Module): - def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4): + def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) - self.norm1 = nn.LayerNorm(d_model) - self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points) - self.gateway = Gate(d_model) - self.linear1 = nn.Linear(d_model, dim_feedforward) + self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations) + self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations) + self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype) self.activation = nn.ReLU() - self.linear2 = nn.Linear(dim_feedforward, d_model) - self.norm3 = nn.LayerNorm(d_model) + self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None): q = k = target if query_pos is None else target + query_pos @@ -474,10 +475,10 @@ class Integral(nn.Module): class LQE(nn.Module): """Location Quality Estimator — refines class scores using corner distribution.""" - def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32): + def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32, device=None, dtype=None, operations=None): super().__init__() self.k, self.reg_max = k, reg_max - self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers) + self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers, device=device, dtype=dtype, operations=operations) def forward(self, scores, pred_corners): B, L, _ = pred_corners.shape @@ -488,8 +489,7 @@ class LQE(nn.Module): class TransformerDecoder(nn.Module): - def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, - num_layers, reg_max, reg_scale, up, eval_idx=-1): + def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, num_layers, reg_max, reg_scale, up, eval_idx=-1, device=None, dtype=None, operations=None): super().__init__() self.hidden_dim = hidden_dim self.num_layers = num_layers @@ -497,10 +497,10 @@ class TransformerDecoder(nn.Module): self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max self.layers = nn.ModuleList([ - TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points) + TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1) ]) - self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max) for _ in range(self.eval_idx + 1)]) + self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1)]) self.register_buffer('project', weighting_function(reg_max, up, reg_scale)) def _value_op(self, memory, spatial_shapes): @@ -557,7 +557,8 @@ class TransformerDecoder(nn.Module): class DFINETransformer(nn.Module): def __init__(self, num_classes=80, hidden_dim=256, num_queries=300, feat_channels=[256, 256, 256], feat_strides=[8, 16, 32], - num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32, reg_scale=8.0, eval_spatial_size=(640, 640)): + num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32, + reg_scale=8.0, eval_spatial_size=(640, 640), device=None, dtype=None, operations=None): super().__init__() assert len(feat_strides) == len(feat_channels) self.hidden_dim = hidden_dim @@ -577,12 +578,12 @@ class DFINETransformer(nn.Module): self.input_proj.append(nn.Identity()) else: self.input_proj.append(nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d(ch, hidden_dim, 1, bias=True))]))) + ('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))]))) in_ch = feat_channels[-1] for i in range(num_levels - len(feat_channels)): self.input_proj.append(nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d(in_ch if i == 0 else hidden_dim, - hidden_dim, 3, 2, 1, bias=True))]))) + ('conv', operations.Conv2d(in_ch if i == 0 else hidden_dim, + hidden_dim, 3, 2, 1, bias=True, device=device, dtype=dtype))]))) in_ch = hidden_dim # FDR parameters (non-trainable placeholders, set from config) @@ -591,21 +592,21 @@ class DFINETransformer(nn.Module): pts = num_points if isinstance(num_points, (list, tuple)) else [num_points] * num_levels self.decoder = TransformerDecoder(hidden_dim, nhead, dim_feedforward, num_levels, pts, - num_layers, reg_max, self.reg_scale, self.up, eval_idx) + num_layers, reg_max, self.reg_scale, self.up, eval_idx, device=device, dtype=dtype, operations=operations) - self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2) + self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2, device=device, dtype=dtype, operations=operations) self.enc_output = nn.Sequential(OrderedDict([ - ('proj', nn.Linear(hidden_dim, hidden_dim)), - ('norm', nn.LayerNorm(hidden_dim))])) - self.enc_score_head = nn.Linear(hidden_dim, num_classes) - self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) + ('proj', operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)), + ('norm', operations.LayerNorm(hidden_dim, device=device, dtype=dtype))])) + self.enc_score_head = operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) + self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations) self.eval_idx_ = eval_idx if eval_idx >= 0 else num_layers + eval_idx self.dec_score_head = nn.ModuleList( - [nn.Linear(hidden_dim, num_classes) for _ in range(self.eval_idx_ + 1)]) - self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) + [operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) for _ in range(self.eval_idx_ + 1)]) + self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations) self.dec_bbox_head = nn.ModuleList( - [MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3) + [MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx_ + 1)]) self.integral = Integral(reg_max) @@ -672,21 +673,21 @@ class DFINETransformer(nn.Module): # --------------------------------------------------------------------------- class RTv4(nn.Module): - def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, num_heads=8, feat_strides=[8, 16, 32], num_head_channels=None, - image_model=None, device=None, dtype=None, operations=None): + def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, feat_strides=[8, 16, 32], device=None, dtype=None, operations=None, **kwargs): super().__init__() + self.device = device self.dtype = dtype self.operations = operations - self.backbone = HGNetv2() - self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff) + self.backbone = HGNetv2(device=device, dtype=dtype, operations=operations) + self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff, device=device, dtype=dtype, operations=operations) self.decoder = DFINETransformer(num_classes=num_classes, hidden_dim=dec_h, num_queries=num_queries, - feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff) + feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff, device=device, dtype=dtype, operations=operations) self.num_classes = num_classes self.num_queries = num_queries - def forward(self, x: torch.Tensor): + def _forward(self, x: torch.Tensor): return self.decoder(self.encoder(self.backbone(x))) def postprocess(self, outputs, orig_target_sizes: torch.Tensor): @@ -698,3 +699,8 @@ class RTv4(nn.Module): labels = idx % self.num_classes boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4)) return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)] + + def forward(self, x: torch.Tensor, orig_target_sizes: torch.Tensor, **kwargs): + x = comfy.model_management.cast_to_device(x, self.device, self.dtype) + outputs = self._forward(x) + return self.postprocess(outputs, orig_target_sizes) diff --git a/comfy_extras/nodes_rfdetr.py b/comfy_extras/nodes_rfdetr.py index 321899bdd..513f36d5e 100644 --- a/comfy_extras/nodes_rfdetr.py +++ b/comfy_extras/nodes_rfdetr.py @@ -29,14 +29,14 @@ class RFDETR_detect(io.ComfyNode): def execute(cls, model, image, threshold, class_name) -> io.NodeOutput: B, H, W, C = image.shape - device = comfy.model_management.get_torch_device() - orig_size = torch.tensor([[W, H]], device=device, dtype=torch.float32).expand(B, -1) # [B, 2] as (W, H) - image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled") + device = comfy.model_management.get_torch_device() + dtype = model.model.get_dtype_inference() + orig_size = torch.tensor([[W, H]], device=device, dtype=dtype).expand(B, -1) # [B, 2] as (W, H) + comfy.model_management.load_model_gpu(model) - out = model.model.diffusion_model(image_in.to(device=device)) # [B, num_queries, 4+num_classes] - results = model.model.diffusion_model.postprocess(out, orig_size) # list of B dicts + results = model.model.diffusion_model(image_in.to(device=device, dtype=dtype), orig_size) # list of B dicts all_bbox_dicts = []