Add BatchNorm2d to ops, support dynamic VRAM

This commit is contained in:
kijai 2026-03-04 00:33:48 +02:00
parent 8377f68c29
commit a888b90f06
3 changed files with 54 additions and 27 deletions

View File

@ -31,7 +31,7 @@ class ConvBNAct(nn.Module):
super().__init__()
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
self.bn = operations.BatchNorm2d(oc, device=device, dtype=dtype)
self.act = nn.ReLU() if use_act else nn.Identity()
def forward(self, x):
@ -399,7 +399,7 @@ class MSDeformableAttention(nn.Module):
attn_w = F.softmax(
self.attention_weights(query).reshape(
bs, Lq, self.num_heads, sum(self.num_points_list)), -1)
scale = self.num_points_scale.to(query.dtype).unsqueeze(-1)
scale = self.num_points_scale.to(query).unsqueeze(-1)
offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale
locs = ref_pts[:, :, None, :, :2] + offset # [bs, Lq, n_head, sum_pts, 2]
return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list)
@ -662,12 +662,12 @@ class DFINETransformer(nn.Module):
shapes.append([h, w])
return torch.cat(flat, 1), shapes
def _decoder_input(self, memory: torch.Tensor, spatial_shapes):
anchors, valid_mask = self.anchors.to(memory.dtype), self.valid_mask
def _decoder_input(self, memory: torch.Tensor):
anchors, valid_mask = self.anchors.to(memory), self.valid_mask
if memory.shape[0] > 1:
anchors = anchors.repeat(memory.shape[0], 1, 1)
mem = valid_mask.to(memory.dtype) * memory
mem = valid_mask.to(memory) * memory
out_mem = self.enc_output(mem)
logits = self.enc_score_head(out_mem)
_, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1)
@ -679,7 +679,7 @@ class DFINETransformer(nn.Module):
def forward(self, feats: List[torch.Tensor]):
memory, shapes = self._encoder_input(feats)
content, ref = self._decoder_input(memory, shapes)
content, ref = self._decoder_input(memory)
out_bboxes, out_logits = self.decoder(
content, ref, memory, shapes,
self.dec_bbox_head, self.dec_score_head,
@ -705,21 +705,21 @@ class RTv4(nn.Module):
self.num_classes = num_classes
self.num_queries = num_queries
self.load_device = comfy.model_management.get_torch_device()
def _forward(self, x: torch.Tensor):
    # backbone -> encoder -> decoder; kept private so forward() can handle
    # device/dtype placement before inference.
    features = self.backbone(x)
    return self.decoder(self.encoder(features))
def postprocess(self, outputs, orig_size: tuple = (640, 640)) -> List[dict]:
    """Convert raw decoder outputs into per-image detection dicts.

    Args:
        outputs: dict with 'pred_logits' and 'pred_boxes'; boxes are in
            normalized cxcywh format. (assumes [B, nq, C] / [B, nq, 4]
            shapes from the indexing below -- TODO confirm upstream)
        orig_size: (width, height) to scale normalized boxes back to pixels.

    Returns:
        List of B dicts with 'labels', 'boxes' (xyxy, pixel units) and
        'scores' tensors, one dict per image.
    """
    logits = outputs['pred_logits']
    # cxcywh -> xyxy, then scale by (W, H, W, H) to pixel coordinates.
    boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy')
    boxes = boxes * torch.tensor(orig_size, device=boxes.device, dtype=boxes.dtype).repeat(1, 2).unsqueeze(1)
    scores = F.sigmoid(logits)
    # Top-k over the flattened (query, class) axis; recover class and query ids.
    scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1)
    labels = idx % self.num_classes
    boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4))
    return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)]
def forward(self, x: torch.Tensor, orig_size: tuple = (640, 640), **kwargs):
    """Run detection end-to-end: place input, infer, postprocess.

    Args:
        x: input image batch; moved to the model's load device and dtype.
        orig_size: (width, height) used by postprocess to scale boxes
            back to pixel coordinates.

    Returns:
        List of per-image detection dicts (see postprocess).
    """
    outputs = self._forward(x.to(device=self.load_device, dtype=self.dtype))
    return self.postprocess(outputs, orig_size)

View File

@ -483,6 +483,35 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
    # Cast-on-use BatchNorm2d: affine parameters and running statistics may
    # live on a different device/dtype than the input and are cast per call.
    def reset_parameters(self):
        # Skip default initialization; weights are expected to come from a
        # loaded checkpoint.
        return None

    def forward_comfy_cast_weights(self, input):
        # Cast affine weight/bias to the input's device/dtype, possibly via
        # an offload stream (offloadable=True), before the functional call.
        if self.weight is not None:
            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
        else:
            weight = None
            bias = None
            offload_stream = None
        # NOTE(review): assumes running_mean/running_var are not None, i.e.
        # track_running_stats=True at construction -- confirm callers.
        x = torch.nn.functional.batch_norm(
            input,
            comfy.model_management.cast_to(self.running_mean, dtype=input.dtype, device=input.device),
            comfy.model_management.cast_to(self.running_var, dtype=input.dtype, device=input.device),
            weight, bias, self.training or not self.track_running_stats,
            self.momentum, self.eps
        )
        # Hand the casted parameters back to their offload location, if any.
        uncast_bias_weight(self, weight, bias, offload_stream)
        return x

    def forward(self, *args, **kwargs):
        run_every_op()
        # Take the casting path when globally enabled or when per-module
        # weight/bias patch functions are registered.
        if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
            return self.forward_comfy_cast_weights(*args, **kwargs)
        else:
            return super().forward(*args, **kwargs)
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -596,6 +625,9 @@ class manual_cast(disable_weight_init):
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True
class BatchNorm2d(disable_weight_init.BatchNorm2d):
    # manual_cast variant: always take the cast-weights forward path.
    comfy_cast_weights = True
def fp8_linear(self, input):
"""

View File

@ -34,24 +34,18 @@ class RTDETR_detect(io.ComfyNode):
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
device = comfy.model_management.get_torch_device()
dtype = model.model.get_dtype_inference()
orig_size = torch.tensor([[W, H]], device=device, dtype=dtype).expand(B, -1) # [B, 2] as (W, H)
comfy.model_management.load_model_gpu(model)
results = model.model.diffusion_model(image_in.to(device=device, dtype=dtype), orig_size) # list of B dicts
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
all_bbox_dicts = []
def _postprocess(results, threshold=0.5):
det = results[0]
keep = det['scores'] > threshold
return det['boxes'][keep].cpu(), det['labels'][keep].cpu(), det['scores'][keep].cpu()
for det in results:
keep = det['scores'] > threshold
boxes = det['boxes'][keep].cpu()
labels = det['labels'][keep].cpu()
scores = det['scores'][keep].cpu()
for i in range(B):
boxes, labels, scores = _postprocess(results[i:i+1], threshold=threshold)
bbox_dicts = sorted([
bbox_dicts = [
{
"x": float(box[0]),
"y": float(box[1]),
@ -62,8 +56,9 @@ class RTDETR_detect(io.ComfyNode):
}
for box, label, score in zip(boxes, labels, scores)
if class_name == "all" or COCO_CLASSES[int(label)] == class_name
], key=lambda d: d["score"], reverse=True)[:max_detections]
all_bbox_dicts.append(bbox_dicts)
]
bbox_dicts.sort(key=lambda d: d["score"], reverse=True)
all_bbox_dicts.append(bbox_dicts[:max_detections])
return io.NodeOutput(all_bbox_dicts)