mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 20:42:31 +08:00
structure generation works
This commit is contained in:
parent
0ebeac98a7
commit
ea255543e6
@ -1,6 +1,7 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
@ -274,8 +275,11 @@ class DINOv3ViTModel(nn.Module):
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
norm = self.norm.to(hidden_states.device)
|
||||
sequence_output = norm(hidden_states)
|
||||
if kwargs.get("skip_norm_elementwise", False):
|
||||
sequence_output= F.layer_norm(hidden_states, hidden_states.shape[-1:])
|
||||
else:
|
||||
norm = self.norm.to(hidden_states.device)
|
||||
sequence_output = norm(hidden_states)
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
|
||||
return sequence_output, None, pooled_output, None
|
||||
|
||||
@ -77,7 +77,7 @@ class LayerNorm32(nn.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
w = self.weight.to(torch.float32)
|
||||
w = self.weight.to(torch.float32) if self.weight is not None else None
|
||||
b = self.bias.to(torch.float32) if self.bias is not None else None
|
||||
|
||||
o = F.layer_norm(x, self.normalized_shape, w, b, self.eps)
|
||||
|
||||
@ -265,13 +265,13 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
||||
|
||||
model_internal.image_size = 512
|
||||
input_512 = prepare_tensor(cropped_img_tensor, 512)
|
||||
cond_512 = model_internal(input_512)[0]
|
||||
cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0]
|
||||
|
||||
cond_1024 = None
|
||||
if include_1024:
|
||||
model_internal.image_size = 1024
|
||||
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
|
||||
cond_1024 = model_internal(input_1024)[0]
|
||||
cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0]
|
||||
|
||||
conditioning = {
|
||||
'cond_512': cond_512.to(device),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user