Fixing missing __init__.py and other errors that appear when using an IDE

This commit is contained in:
doctorpangloss 2024-06-18 21:16:37 -07:00
parent ef38598030
commit 2aecff06ff
4 changed files with 9 additions and 9 deletions

View File

View File

@ -4,8 +4,8 @@ import torch
from torch import nn from torch import nn
from typing import Literal, Dict, Any from typing import Literal, Dict, Any
import math import math
import comfy.ops from ... import ops
ops = comfy.ops.disable_weight_init ops = ops.disable_weight_init
def vae_sample(mean, scale): def vae_sample(mean, scale):
stdev = nn.functional.softplus(scale) + 1e-4 stdev = nn.functional.softplus(scale) + 1e-4

View File

@ -1,6 +1,7 @@
# code adapted from: https://github.com/Stability-AI/stable-audio-tools # code adapted from: https://github.com/Stability-AI/stable-audio-tools
from einops.layers.torch import Rearrange
from comfy.ldm.modules.attention import optimized_attention from ..modules.attention import optimized_attention
import typing as tp import typing as tp
import torch import torch
@ -153,6 +154,8 @@ class RotaryEmbedding(nn.Module):
return self.forward(t) return self.forward(t)
def forward(self, t): def forward(self, t):
# todo: ???
seq_len = 0
# device = self.inv_freq.device # device = self.inv_freq.device
device = t.device device = t.device
dtype = t.dtype dtype = t.dtype
@ -343,7 +346,7 @@ class Attention(nn.Module):
# determine masking # determine masking
masks = [] masks = []
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account # todo: ???
if input_mask is not None: if input_mask is not None:
input_mask = rearrange(input_mask, 'b j -> b 1 1 j') input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
@ -351,9 +354,6 @@ class Attention(nn.Module):
# Other masks will be added here later # Other masks will be added here later
if len(masks) > 0:
final_attn_mask = ~or_reduce(masks)
n, device = q.shape[-2], q.device n, device = q.shape[-2], q.device
causal = self.causal if causal is None else causal causal = self.causal if causal is None else causal

View File

@ -6,7 +6,7 @@ from torch import Tensor, einsum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from einops import rearrange from einops import rearrange
import math import math
import comfy.ops from ... import ops
class LearnedPositionalEmbedding(nn.Module): class LearnedPositionalEmbedding(nn.Module):
"""Used for continuous time""" """Used for continuous time"""
@ -27,7 +27,7 @@ class LearnedPositionalEmbedding(nn.Module):
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
return nn.Sequential( return nn.Sequential(
LearnedPositionalEmbedding(dim), LearnedPositionalEmbedding(dim),
comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features), ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
) )