mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-18 15:44:55 +08:00
Add BatchNorm2d to ops, support dynamic vram
This commit is contained in:
parent
8377f68c29
commit
a888b90f06
@ -31,7 +31,7 @@ class ConvBNAct(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
|
||||
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
|
||||
self.bn = operations.BatchNorm2d(oc, device=device, dtype=dtype)
|
||||
self.act = nn.ReLU() if use_act else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
@ -399,7 +399,7 @@ class MSDeformableAttention(nn.Module):
|
||||
attn_w = F.softmax(
|
||||
self.attention_weights(query).reshape(
|
||||
bs, Lq, self.num_heads, sum(self.num_points_list)), -1)
|
||||
scale = self.num_points_scale.to(query.dtype).unsqueeze(-1)
|
||||
scale = self.num_points_scale.to(query).unsqueeze(-1)
|
||||
offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale
|
||||
locs = ref_pts[:, :, None, :, :2] + offset # [bs, Lq, n_head, sum_pts, 2]
|
||||
return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list)
|
||||
@ -662,12 +662,12 @@ class DFINETransformer(nn.Module):
|
||||
shapes.append([h, w])
|
||||
return torch.cat(flat, 1), shapes
|
||||
|
||||
def _decoder_input(self, memory: torch.Tensor, spatial_shapes):
|
||||
anchors, valid_mask = self.anchors.to(memory.dtype), self.valid_mask
|
||||
def _decoder_input(self, memory: torch.Tensor):
|
||||
anchors, valid_mask = self.anchors.to(memory), self.valid_mask
|
||||
if memory.shape[0] > 1:
|
||||
anchors = anchors.repeat(memory.shape[0], 1, 1)
|
||||
|
||||
mem = valid_mask.to(memory.dtype) * memory
|
||||
mem = valid_mask.to(memory) * memory
|
||||
out_mem = self.enc_output(mem)
|
||||
logits = self.enc_score_head(out_mem)
|
||||
_, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1)
|
||||
@ -679,7 +679,7 @@ class DFINETransformer(nn.Module):
|
||||
|
||||
def forward(self, feats: List[torch.Tensor]):
|
||||
memory, shapes = self._encoder_input(feats)
|
||||
content, ref = self._decoder_input(memory, shapes)
|
||||
content, ref = self._decoder_input(memory)
|
||||
out_bboxes, out_logits = self.decoder(
|
||||
content, ref, memory, shapes,
|
||||
self.dec_bbox_head, self.dec_score_head,
|
||||
@ -705,21 +705,21 @@ class RTv4(nn.Module):
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.num_queries = num_queries
|
||||
self.load_device = comfy.model_management.get_torch_device()
|
||||
|
||||
def _forward(self, x: torch.Tensor):
|
||||
return self.decoder(self.encoder(self.backbone(x)))
|
||||
|
||||
def postprocess(self, outputs, orig_target_sizes: torch.Tensor):
|
||||
def postprocess(self, outputs, orig_size: tuple = (640, 640)) -> List[dict]:
|
||||
logits = outputs['pred_logits']
|
||||
boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy')
|
||||
boxes = boxes * orig_target_sizes.repeat(1, 2).unsqueeze(1)
|
||||
boxes = boxes * torch.tensor(orig_size, device=boxes.device, dtype=boxes.dtype).repeat(1, 2).unsqueeze(1)
|
||||
scores = F.sigmoid(logits)
|
||||
scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1)
|
||||
labels = idx % self.num_classes
|
||||
boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4))
|
||||
return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)]
|
||||
|
||||
def forward(self, x: torch.Tensor, orig_target_sizes: torch.Tensor, **kwargs):
|
||||
x = comfy.model_management.cast_to_device(x, self.device, self.dtype)
|
||||
outputs = self._forward(x)
|
||||
return self.postprocess(outputs, orig_target_sizes)
|
||||
def forward(self, x: torch.Tensor, orig_size: tuple = (640, 640), **kwargs):
|
||||
outputs = self._forward(x.to(device=self.load_device, dtype=self.dtype))
|
||||
return self.postprocess(outputs, orig_size)
|
||||
|
||||
32
comfy/ops.py
32
comfy/ops.py
@ -483,6 +483,35 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
else:
|
||||
weight = None
|
||||
bias = None
|
||||
offload_stream = None
|
||||
|
||||
x = torch.nn.functional.batch_norm(
|
||||
input,
|
||||
comfy.model_management.cast_to(self.running_mean, dtype=input.dtype, device=input.device),
|
||||
comfy.model_management.cast_to(self.running_var, dtype=input.dtype, device=input.device),
|
||||
weight, bias, self.training or not self.track_running_stats,
|
||||
self.momentum, self.eps
|
||||
)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
@ -596,6 +625,9 @@ class manual_cast(disable_weight_init):
|
||||
class Embedding(disable_weight_init.Embedding):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class BatchNorm2d(disable_weight_init.BatchNorm2d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
|
||||
def fp8_linear(self, input):
|
||||
"""
|
||||
|
||||
@ -34,24 +34,18 @@ class RTDETR_detect(io.ComfyNode):
|
||||
|
||||
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
|
||||
|
||||
device = comfy.model_management.get_torch_device()
|
||||
dtype = model.model.get_dtype_inference()
|
||||
orig_size = torch.tensor([[W, H]], device=device, dtype=dtype).expand(B, -1) # [B, 2] as (W, H)
|
||||
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
results = model.model.diffusion_model(image_in.to(device=device, dtype=dtype), orig_size) # list of B dicts
|
||||
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
|
||||
|
||||
all_bbox_dicts = []
|
||||
|
||||
def _postprocess(results, threshold=0.5):
|
||||
det = results[0]
|
||||
keep = det['scores'] > threshold
|
||||
return det['boxes'][keep].cpu(), det['labels'][keep].cpu(), det['scores'][keep].cpu()
|
||||
for det in results:
|
||||
keep = det['scores'] > threshold
|
||||
boxes = det['boxes'][keep].cpu()
|
||||
labels = det['labels'][keep].cpu()
|
||||
scores = det['scores'][keep].cpu()
|
||||
|
||||
for i in range(B):
|
||||
boxes, labels, scores = _postprocess(results[i:i+1], threshold=threshold)
|
||||
|
||||
bbox_dicts = sorted([
|
||||
bbox_dicts = [
|
||||
{
|
||||
"x": float(box[0]),
|
||||
"y": float(box[1]),
|
||||
@ -62,8 +56,9 @@ class RTDETR_detect(io.ComfyNode):
|
||||
}
|
||||
for box, label, score in zip(boxes, labels, scores)
|
||||
if class_name == "all" or COCO_CLASSES[int(label)] == class_name
|
||||
], key=lambda d: d["score"], reverse=True)[:max_detections]
|
||||
all_bbox_dicts.append(bbox_dicts)
|
||||
]
|
||||
bbox_dicts.sort(key=lambda d: d["score"], reverse=True)
|
||||
all_bbox_dicts.append(bbox_dicts[:max_detections])
|
||||
|
||||
return io.NodeOutput(all_bbox_dicts)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user