Fixing missing __init__.py and other errors that appear when using an IDE

doctorpangloss 2024-06-18 21:16:37 -07:00
parent ef38598030
commit 2aecff06ff
4 changed files with 9 additions and 9 deletions
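In short, the change replaces absolute `comfy.*` imports with package-relative ones and adds the `__init__.py` that lets an IDE (and Python itself) resolve them: relative imports only work inside a proper package, which is why the missing `__init__.py` produced the IDE errors in the first place. A minimal sketch of the relationship; the `comfy/ldm/audio/` path is an inference from the import depth, and `some_module.py` is a hypothetical file name:

# Assumed package layout; only the relative depth is certain from the diff.
#
#   comfy/
#       __init__.py          <- must exist for "from ... import ops" to resolve
#       ops.py
#       ldm/
#           __init__.py
#           audio/
#               __init__.py  <- the "missing __init__.py" of the commit title
#               some_module.py
#
# Inside comfy/ldm/audio/some_module.py (hypothetical name):
from ... import ops              # "..." climbs from comfy.ldm.audio up to comfy
ops = ops.disable_weight_init    # rebind the name, as the diff does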

@@ -4,8 +4,8 @@ import torch
 from torch import nn
 from typing import Literal, Dict, Any
 import math
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ... import ops
+ops = ops.disable_weight_init

 def vae_sample(mean, scale):
     stdev = nn.functional.softplus(scale) + 1e-4

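For context on what is being imported: `disable_weight_init` in ComfyUI's `comfy/ops.py` is a container of `torch.nn` layer subclasses whose `reset_parameters` is a no-op, so models can be instantiated without random initialization that a checkpoint load would overwrite anyway. A simplified sketch of the pattern, not a verbatim copy of `comfy/ops.py`:

import torch.nn as nn

class disable_weight_init:
    class Linear(nn.Linear):
        def reset_parameters(self):
            # Skip random init; real weights arrive from the checkpoint.
            return None

After `ops = ops.disable_weight_init`, the module writes `ops.Linear(...)` exactly as it would `nn.Linear(...)`.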

@@ -1,6 +1,7 @@
 # code adapted from: https://github.com/Stability-AI/stable-audio-tools

 from einops.layers.torch import Rearrange
-from comfy.ldm.modules.attention import optimized_attention
+from ..modules.attention import optimized_attention
+
 import typing as tp
 import torch
@@ -153,6 +154,8 @@ class RotaryEmbedding(nn.Module):
         return self.forward(t)

     def forward(self, t):
+        # todo: ???
+        seq_len = 0
         # device = self.inv_freq.device
         device = t.device
         dtype = t.dtype
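The `seq_len = 0` placeholder is the kind of IDE fix the commit title refers to: in the code adapted from stable-audio-tools, a later branch of `forward` (the xpos/scale path) reads `seq_len` without any prior assignment, so static analysis flags an undefined name. A hedged sketch of the shape of the problem; everything except the placeholder line is an assumption reconstructed from the upstream project:

import torch
from torch import nn

class RotaryEmbeddingSketch(nn.Module):
    def __init__(self, dim, scale_base=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.scale = None  # xpos scaling off; the buggy branch is then skipped
        self.scale_base = scale_base

    def forward(self, t):
        seq_len = 0  # todo: placeholder so the name exists on every path
        device = t.device
        freqs = torch.einsum('i,j->ij', t.float(), self.inv_freq)
        if self.scale is None:
            return freqs, 1.0
        # Before the placeholder, this line referenced an undefined `seq_len`:
        power = (torch.arange(seq_len, device=device) - (seq_len // 2)) / self.scale_base
        return freqs, self.scale ** power.unsqueeze(-1)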
@@ -343,7 +346,7 @@ class Attention(nn.Module):

         # determine masking
         masks = []
-        final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
+        # todo: ???

         if input_mask is not None:
             input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
@@ -351,9 +354,6 @@ class Attention(nn.Module):

         # Other masks will be added here later
-
-        if len(masks) > 0:
-            final_attn_mask = ~or_reduce(masks)

         n, device = q.shape[-2], q.device

         causal = self.causal if causal is None else causal
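The deleted `or_reduce` call is likely the other unresolved reference: the helper exists in the upstream stable-audio-tools/x-transformers code, where it ORs a list of boolean masks together, but if it was never carried into this module the IDE reports it as undefined. A sketch of what the upstream helper does; this is an assumption based on those libraries, not code from this repository:

from functools import reduce
import operator
import torch

def or_reduce(masks):
    # Combine a list of boolean mask tensors with elementwise OR.
    return reduce(operator.or_, masks)

combined = or_reduce([torch.tensor([True, False]), torch.tensor([False, True])])
# tensor([True, True])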


@@ -6,7 +6,7 @@ from torch import Tensor, einsum
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
 from einops import rearrange
 import math
-import comfy.ops
+from ... import ops

 class LearnedPositionalEmbedding(nn.Module):
     """Used for continuous time"""
@@ -27,7 +27,7 @@ class LearnedPositionalEmbedding(nn.Module):

 def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
     return nn.Sequential(
         LearnedPositionalEmbedding(dim),
-        comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
+        ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
     )
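The final hunk is the same import fix at a use site: `manual_cast.Linear` (ComfyUI's `nn.Linear` variant that casts its weights to the input's dtype and device at call time) is now reached through the relative `ops` import rather than the absolute `comfy.ops` path. A hedged usage sketch; the dimensions are made up, and the 1-D input convention is assumed from the stable-audio-tools original:

import torch

emb = TimePositionalEmbedding(dim=256, out_features=768)
t = torch.rand(8)    # a batch of continuous time values
y = emb(t)           # expected shape: (8, 768)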