structure generation works

This commit is contained in:
Yousef Rafat 2026-04-10 14:24:07 +02:00
parent 0ebeac98a7
commit ea255543e6
3 changed files with 9 additions and 5 deletions

View File

@@ -1,6 +1,7 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
@@ -274,8 +275,11 @@ class DINOv3ViTModel(nn.Module):
position_embeddings=position_embeddings,
)
norm = self.norm.to(hidden_states.device)
sequence_output = norm(hidden_states)
if kwargs.get("skip_norm_elementwise", False):
sequence_output= F.layer_norm(hidden_states, hidden_states.shape[-1:])
else:
norm = self.norm.to(hidden_states.device)
sequence_output = norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return sequence_output, None, pooled_output, None

View File

@@ -77,7 +77,7 @@ class LayerNorm32(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
x = x.to(torch.float32)
w = self.weight.to(torch.float32)
w = self.weight.to(torch.float32) if self.weight is not None else None
b = self.bias.to(torch.float32) if self.bias is not None else None
o = F.layer_norm(x, self.normalized_shape, w, b, self.eps)

View File

@@ -265,13 +265,13 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True):
model_internal.image_size = 512
input_512 = prepare_tensor(cropped_img_tensor, 512)
cond_512 = model_internal(input_512)[0]
cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0]
cond_1024 = None
if include_1024:
model_internal.image_size = 1024
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
cond_1024 = model_internal(input_1024)[0]
cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0]
conditioning = {
'cond_512': cond_512.to(device),