diff --git a/comfy/ldm/rt_detr/rtdetr_v4.py b/comfy/ldm/rt_detr/rtdetr_v4.py index f44a88279..3233dbdf6 100644 --- a/comfy/ldm/rt_detr/rtdetr_v4.py +++ b/comfy/ldm/rt_detr/rtdetr_v4.py @@ -31,7 +31,7 @@ class ConvBNAct(nn.Module): super().__init__() self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype) - self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype) + self.bn = operations.BatchNorm2d(oc, device=device, dtype=dtype) self.act = nn.ReLU() if use_act else nn.Identity() def forward(self, x): @@ -399,7 +399,7 @@ class MSDeformableAttention(nn.Module): attn_w = F.softmax( self.attention_weights(query).reshape( bs, Lq, self.num_heads, sum(self.num_points_list)), -1) - scale = self.num_points_scale.to(query.dtype).unsqueeze(-1) + scale = self.num_points_scale.to(query).unsqueeze(-1) offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale locs = ref_pts[:, :, None, :, :2] + offset # [bs, Lq, n_head, sum_pts, 2] return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list) @@ -662,12 +662,12 @@ class DFINETransformer(nn.Module): shapes.append([h, w]) return torch.cat(flat, 1), shapes - def _decoder_input(self, memory: torch.Tensor, spatial_shapes): - anchors, valid_mask = self.anchors.to(memory.dtype), self.valid_mask + def _decoder_input(self, memory: torch.Tensor): + anchors, valid_mask = self.anchors.to(memory), self.valid_mask if memory.shape[0] > 1: anchors = anchors.repeat(memory.shape[0], 1, 1) - mem = valid_mask.to(memory.dtype) * memory + mem = valid_mask.to(memory) * memory out_mem = self.enc_output(mem) logits = self.enc_score_head(out_mem) _, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1) @@ -679,7 +679,7 @@ class DFINETransformer(nn.Module): def forward(self, feats: List[torch.Tensor]): memory, shapes = self._encoder_input(feats) - content, ref = self._decoder_input(memory, shapes) + content, ref = self._decoder_input(memory) out_bboxes, out_logits = self.decoder( content, ref, memory, shapes, self.dec_bbox_head, self.dec_score_head, @@ -705,21 +705,21 @@ class RTv4(nn.Module): self.num_classes = num_classes self.num_queries = num_queries + self.load_device = comfy.model_management.get_torch_device() def _forward(self, x: torch.Tensor): return self.decoder(self.encoder(self.backbone(x))) - def postprocess(self, outputs, orig_target_sizes: torch.Tensor): + def postprocess(self, outputs, orig_size: tuple = (640, 640)) -> List[dict]: logits = outputs['pred_logits'] boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy') - boxes = boxes * orig_target_sizes.repeat(1, 2).unsqueeze(1) + boxes = boxes * torch.tensor(orig_size, device=boxes.device, dtype=boxes.dtype).repeat(1, 2).unsqueeze(1) scores = F.sigmoid(logits) scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1) labels = idx % self.num_classes boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4)) return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)] - def forward(self, x: torch.Tensor, orig_target_sizes: torch.Tensor, **kwargs): - x = comfy.model_management.cast_to_device(x, self.device, self.dtype) - outputs = self._forward(x) - return self.postprocess(outputs, orig_target_sizes) + def forward(self, x: torch.Tensor, orig_size: tuple = (640, 640), **kwargs): + outputs = self._forward(x.to(device=self.load_device, dtype=self.dtype)) + return self.postprocess(outputs, orig_size) diff --git a/comfy/ops.py b/comfy/ops.py index 6ee6075fb..09066d9ee 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -483,6 +483,35 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) + class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + if self.weight is not None: + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + else: + weight = None + bias = None + offload_stream = None + + x = torch.nn.functional.batch_norm( + input, + comfy.model_management.cast_to(self.running_mean, dtype=input.dtype, device=input.device), + comfy.model_management.cast_to(self.running_var, dtype=input.dtype, device=input.device), + weight, bias, self.training or not self.track_running_stats, + self.momentum, self.eps + ) + uncast_bias_weight(self, weight, bias, offload_stream) + return x + + def forward(self, *args, **kwargs): + run_every_op() + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): return None @@ -596,6 +625,9 @@ class manual_cast(disable_weight_init): class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True + class BatchNorm2d(disable_weight_init.BatchNorm2d): + comfy_cast_weights = True + def fp8_linear(self, input): """ diff --git a/comfy_extras/nodes_rtdetr.py b/comfy_extras/nodes_rtdetr.py index 5e78065f7..60c3c9b92 100644 --- a/comfy_extras/nodes_rtdetr.py +++ b/comfy_extras/nodes_rtdetr.py @@ -34,24 +34,18 @@ class RTDETR_detect(io.ComfyNode): image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled") - device = comfy.model_management.get_torch_device() - dtype = model.model.get_dtype_inference() - orig_size = torch.tensor([[W, H]], device=device, dtype=dtype).expand(B, -1) # [B, 2] as (W, H) - comfy.model_management.load_model_gpu(model) - results = model.model.diffusion_model(image_in.to(device=device, dtype=dtype), orig_size) # list of B dicts + results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts all_bbox_dicts = [] - def _postprocess(results, threshold=0.5): - det = results[0] - keep = det['scores'] > threshold - return det['boxes'][keep].cpu(), det['labels'][keep].cpu(), det['scores'][keep].cpu() + for det in results: + keep = det['scores'] > threshold + boxes = det['boxes'][keep].cpu() + labels = det['labels'][keep].cpu() + scores = det['scores'][keep].cpu() - for i in range(B): - boxes, labels, scores = _postprocess(results[i:i+1], threshold=threshold) - - bbox_dicts = sorted([ + bbox_dicts = [ { "x": float(box[0]), "y": float(box[1]), @@ -62,8 +56,9 @@ class RTDETR_detect(io.ComfyNode): } for box, label, score in zip(boxes, labels, scores) if class_name == "all" or COCO_CLASSES[int(label)] == class_name - ], key=lambda d: d["score"], reverse=True)[:max_detections] - all_bbox_dicts.append(bbox_dicts) + ] + bbox_dicts.sort(key=lambda d: d["score"], reverse=True) + all_bbox_dicts.append(bbox_dicts[:max_detections]) return io.NodeOutput(all_bbox_dicts)