wip make package structure coherent

Benjamin Berman 2023-08-22 11:35:20 -07:00
parent bc4e52f790
commit d25d5e75f0
37 changed files with 216 additions and 196 deletions
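The theme of the commit is mechanical: every absolute `comfy.*` import inside the package is rewritten as a package-relative import, so the code no longer assumes `comfy` is importable from `sys.path` under exactly that name. A minimal sketch of the pattern, using a hypothetical module `comfy/cmd/example.py` (not a file in this commit):

# Before: absolute imports, which break if the package is vendored,
# renamed, or run without the repo root on sys.path.
import comfy.utils
from comfy.cli_args import args

# After: relative imports resolve against this module's own package
# (comfy.cmd), so ".." is the comfy package itself regardless of
# where or under what name it is installed.
from .. import utils
from ..cli_args import args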


@@ -14,15 +14,15 @@
__version__ = "1.0.0"
# import ApiClient
-from comfy.api.api_client import ApiClient
+from api_client import ApiClient
# import Configuration
-from comfy.api.configuration import Configuration
+from configuration import Configuration
# import exceptions
-from comfy.api.exceptions import OpenApiException
-from comfy.api.exceptions import ApiAttributeError
-from comfy.api.exceptions import ApiTypeError
-from comfy.api.exceptions import ApiValueError
-from comfy.api.exceptions import ApiKeyError
-from comfy.api.exceptions import ApiException
+from exceptions import OpenApiException
+from exceptions import ApiAttributeError
+from exceptions import ApiTypeError
+from exceptions import ApiValueError
+from exceptions import ApiKeyError
+from exceptions import ApiException
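Note that the added lines have no leading dot: under Python 3, `from api_client import ApiClient` is an absolute import and would fail inside `comfy/api/__init__.py` unless a top-level `api_client` module happens to exist. Given the commit's stated goal, the intended form is presumably the relative `from .api_client import ApiClient` (and likewise for `configuration` and `exceptions`); either the dots were lost in rendering, or this is part of what makes the commit "wip".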


@@ -0,0 +1,2 @@
+class Configuration:
+    pass

comfy/api/exceptions.py (new file)

@@ -0,0 +1,22 @@
+class OpenApiException:
+    pass
+
+
+class ApiAttributeError:
+    pass
+
+
+class ApiTypeError:
+    pass
+
+
+class ApiValueError:
+    pass
+
+
+class ApiKeyError:
+    pass
+
+
+class ApiException:
+    pass
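One caveat on these stubs as committed: classes that derive only from `object` cannot be raised or caught, so `raise ApiException(...)` fails with `TypeError: exceptions must derive from BaseException`. If the placeholders are meant to be usable, a sketch of the same file with that fixed (the hierarchy below is an assumption modeled on OpenAPI generator output, not taken from this commit) would be:

class OpenApiException(Exception):
    """Base class for the generated API client's errors."""

class ApiTypeError(OpenApiException, TypeError):
    pass

class ApiValueError(OpenApiException, ValueError):
    pass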


@@ -13,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
-import comfy.ops
+from .. import ops
class ControlledUnetModel(UNetModel):
#implemented in the ldm unet
@@ -54,7 +54,7 @@ class ControlNet(nn.Module):
adm_in_channels=None,
transformer_depth_middle=None,
device=None,
-operations=comfy.ops,
+operations=ops,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
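The `operations=ops` default that keeps changing in this diff is a module-as-argument form of dependency injection: the model builds its layers through whatever namespace exposes `Linear` and `Conv2d`, and the caller decides which one. A minimal sketch of the idea (the `SkipInit` namespace below is illustrative, not comfy's actual `ops` module):

import torch.nn as nn

class SkipInit:
    """Stand-in ops namespace whose layers skip default weight init."""
    class Linear(nn.Linear):
        def reset_parameters(self):
            return None  # weights get overwritten by the checkpoint anyway

def build_proj(dim_in, dim_out, operations=nn):
    # `operations` may be torch.nn, comfy's ops module, or any namespace
    # exposing a compatible Linear class.
    return operations.Linear(dim_in, dim_out)

proj = build_proj(320, 640, operations=SkipInit)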


@@ -2,12 +2,12 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPIm
from .utils import load_torch_file, transformers_convert
import os
import torch
-import comfy.ops
+from . import ops
class ClipVisionModel():
def __init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config)
-with comfy.ops.use_comfy_ops():
+with ops.use_comfy_ops():
with modeling_utils.no_init_weights():
self.model = CLIPVisionModelWithProjection(config)
self.processor = CLIPImageProcessor(crop_size=224,
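`ops.use_comfy_ops()` wraps model construction so that layers created inside the block (here, by transformers) come from comfy's ops instead of stock torch.nn. Its body is not shown in this diff; a plausible reconstruction, assuming it just swaps `torch.nn.Linear` for the package's variant and restores it afterwards:

import contextlib
import torch

@contextlib.contextmanager
def use_comfy_ops(linear_cls=torch.nn.Linear):  # hypothetical sketch
    original = torch.nn.Linear
    torch.nn.Linear = linear_cls
    try:
        yield
    finally:
        torch.nn.Linear = original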


@@ -1,6 +1,6 @@
import os
import importlib.util
-from comfy.cli_args import args
+from ..cli_args import args
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():


@@ -14,9 +14,9 @@ import sys
import torch
-from comfy.nodes.package import import_all_nodes_in_workspace
+from ..nodes.package import import_all_nodes_in_workspace
nodes = import_all_nodes_in_workspace()
-import comfy.model_management
+from .. import model_management
"""
A queued item
@@ -112,7 +112,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
results = []
if input_is_list:
if allow_interrupt:
-comfy.model_management.throw_exception_if_processing_interrupted()
+model_management.throw_exception_if_processing_interrupted()
results.append(getattr(obj, func)(**input_data_all))
elif max_len_input == 0:
if allow_interrupt:
@@ -121,7 +121,7 @@
else:
for i in range(max_len_input):
if allow_interrupt:
-comfy.model_management.throw_exception_if_processing_interrupted()
+model_management.throw_exception_if_processing_interrupted()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
return results
@@ -207,7 +207,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if server.client_id is not None:
server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id},
server.client_id)
-except comfy.model_management.InterruptProcessingException as iex:
+except model_management.InterruptProcessingException as iex:
print("Processing interrupted")
# skip formatting inputs/outputs
@@ -332,7 +332,7 @@ class PromptExecutor:
# First, send back the status to the frontend depending
# on the exception type
-if isinstance(ex, comfy.model_management.InterruptProcessingException):
+if isinstance(ex, model_management.InterruptProcessingException):
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
@@ -369,7 +369,7 @@
del d
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
-comfy.model_management.interrupt_current_processing(False)
+model_management.interrupt_current_processing(False)
if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"]
@@ -409,7 +409,7 @@
d = self.outputs_ui.pop(x)
del d
-comfy.model_management.cleanup_models()
+model_management.cleanup_models()
if self.server.client_id is not None:
self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id},
self.server.client_id)


@@ -1,8 +1,8 @@
import torch
from PIL import Image
import numpy as np
-from comfy.cli_args import args, LatentPreviewMethod
-from comfy.taesd.taesd import TAESD
+from ..cli_args import args, LatentPreviewMethod
+from ..taesd.taesd import TAESD
from ..cmd import folder_paths
MAX_PREVIEW_RESOLUTION = 512


@@ -1,7 +1,7 @@
import os
import importlib.util
-from comfy.cmd import cuda_malloc
+from ..cmd import cuda_malloc
from ..cmd import folder_paths
import time
@@ -52,7 +52,7 @@ import shutil
import threading
import gc
-from comfy.cli_args import args
+from ..cli_args import args
if os.name == "nt":
import logging
@@ -63,13 +63,13 @@ if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device)
-import comfy.utils
+from .. import utils
import yaml
from ..cmd import execution
from ..cmd import server as server_module
from .server import BinaryEventTypes
-import comfy.model_management
+from .. import model_management
def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer):
@@ -85,7 +85,7 @@ def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer)
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
gc.collect()
-comfy.model_management.soft_empty_cache()
+model_management.soft_empty_cache()
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
@@ -96,7 +96,7 @@ def hijack_progress(server):
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
-comfy.utils.set_progress_bar_global_hook(hook)
+utils.set_progress_bar_global_hook(hook)
def cleanup_temp():
@@ -127,8 +127,8 @@ def load_extra_path_config(yaml_path):
def cuda_malloc_warning():
-device = comfy.model_management.get_torch_device()
-device_name = comfy.model_management.get_torch_device_name(device)
+device = model_management.get_torch_device()
+device_name = model_management.get_torch_device_name(device)
cuda_malloc_warning = False
if "cudaMallocAsync" in device_name:
for b in cuda_malloc.blacklist:


@@ -23,12 +23,12 @@ from ..cmd import execution
from ..cmd import folder_paths
import mimetypes
-from comfy.digest import digest
-from comfy.cli_args import args
-import comfy.utils
-import comfy.model_management
-from comfy.nodes.package import import_all_nodes_in_workspace
-from comfy.vendor.appdirs import user_data_dir
+from ..digest import digest
+from ..cli_args import args
+from .. import utils
+from .. import model_management
+from ..nodes.package import import_all_nodes_in_workspace
+from ..vendor.appdirs import user_data_dir
nodes = import_all_nodes_in_workspace()
@@ -358,7 +358,7 @@ class PromptServer():
safetensors_path = folder_paths.get_full_path(folder_name, filename)
if safetensors_path is None:
return web.Response(status=404)
-out = comfy.utils.safetensors_header(safetensors_path, max_size=1024 * 1024)
+out = utils.safetensors_header(safetensors_path, max_size=1024 * 1024)
if out is None:
return web.Response(status=404)
dt = json.loads(out)
@@ -368,10 +368,10 @@
@routes.get("/system_stats")
async def get_queue(request):
-device = comfy.model_management.get_torch_device()
-device_name = comfy.model_management.get_torch_device_name(device)
-vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
-vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
+device = model_management.get_torch_device()
+device_name = model_management.get_torch_device_name(device)
+vram_total, torch_vram_total = model_management.get_total_memory(device, torch_total_too=True)
+vram_free, torch_vram_free = model_management.get_free_memory(device, torch_free_too=True)
system_stats = {
"system": {
"os": os.name,
@@ -507,7 +507,7 @@
@routes.post("/interrupt")
async def post_interrupt(request):
-comfy.model_management.interrupt_current_processing()
+model_management.interrupt_current_processing()
return web.Response(status=200)
@routes.post("/history")
@@ -654,7 +654,7 @@
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([
-web.static('/extensions/' + urllib.parse.quote(name), dir, follow_symlinks=True),
+web.static('/extensions/' + quote(name), dir, follow_symlinks=True),
])
self.app.add_routes([


@@ -2,7 +2,7 @@ import json
import yaml
from .cmd import folder_paths
-from comfy.sd import load_checkpoint
+from .sd import load_checkpoint
import os.path as osp
import torch
from safetensors.torch import load_file


@@ -3,11 +3,11 @@ import torch
import torch.nn.functional as F
from contextlib import contextmanager
-from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
-from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from ..modules.diffusionmodules.model import Encoder, Decoder
+from ..modules.distributions.distributions import DiagonalGaussianDistribution
-from comfy.ldm.util import instantiate_from_config
-from comfy.ldm.modules.ema import LitEma
+from ..util import instantiate_from_config
+from ..modules.ema import LitEma
# class AutoencoderKL(pl.LightningModule):
class AutoencoderKL(torch.nn.Module):


@@ -4,7 +4,7 @@ import torch
import numpy as np
from tqdm import tqdm
-from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+from ...modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):


@@ -5,8 +5,8 @@ import numpy as np
from tqdm import tqdm
from functools import partial
-from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
-from comfy.ldm.models.diffusion.sampling_util import norm_thresholding
+from ...modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from .sampling_util import norm_thresholding
class PLMSSampler(object):


@@ -8,15 +8,14 @@ from typing import Optional, Any
from .diffusionmodules.util import checkpoint
from .sub_quadratic_attention import efficient_dot_product_attention
-from comfy import model_management
+from ... import model_management
if model_management.xformers_enabled():
import xformers
import xformers.ops
-from comfy.cli_args import args
-import comfy.ops
+from ...cli_args import args
+from ... import ops
# CrossAttn precision handling
if args.dont_upcast_attention:
@@ -53,7 +52,7 @@ def init_(tensor):
# feedforward
class GEGLU(nn.Module):
-def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
+def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
@@ -63,7 +62,7 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
-def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
+def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
@@ -149,7 +148,7 @@ class SpatialSelfAttention(nn.Module):
class CrossAttentionBirchSan(nn.Module):
-def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
+def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -246,7 +245,7 @@ class CrossAttentionBirchSan(nn.Module):
class CrossAttentionDoggettx(nn.Module):
-def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
+def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -344,7 +343,7 @@ class CrossAttentionDoggettx(nn.Module):
return self.to_out(r2)
class CrossAttention(nn.Module):
-def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
+def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -400,7 +399,7 @@ class CrossAttention(nn.Module):
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
+def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=ops):
super().__init__()
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads.")
@@ -451,7 +450,7 @@ class MemoryEfficientCrossAttention(nn.Module):
return self.to_out(out)
class CrossAttentionPytorch(nn.Module):
-def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
+def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -509,7 +508,7 @@ else:
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
+disable_self_attn=False, dtype=None, device=None, operations=ops):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
@@ -649,7 +648,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
-use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
+use_checkpoint=True, dtype=None, device=None, operations=ops):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth


@@ -7,8 +7,8 @@ from einops import rearrange
from typing import Optional, Any
from ..attention import MemoryEfficientCrossAttention
-from comfy import model_management
-import comfy.ops
+from .... import model_management
+from .... import ops
if model_management.xformers_enabled_vae():
import xformers
@@ -49,7 +49,7 @@ class Upsample(nn.Module):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
-self.conv = comfy.ops.Conv2d(in_channels,
+self.conv = ops.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
@@ -68,7 +68,7 @@ class Downsample(nn.Module):
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
-self.conv = comfy.ops.Conv2d(in_channels,
+self.conv = ops.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
@@ -96,30 +96,30 @@ class ResnetBlock(nn.Module):
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
-self.conv1 = comfy.ops.Conv2d(in_channels,
+self.conv1 = ops.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
-self.temb_proj = comfy.ops.Linear(temb_channels,
+self.temb_proj = ops.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
-self.conv2 = comfy.ops.Conv2d(out_channels,
+self.conv2 = ops.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
-self.conv_shortcut = comfy.ops.Conv2d(in_channels,
+self.conv_shortcut = ops.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
-self.nin_shortcut = comfy.ops.Conv2d(in_channels,
+self.nin_shortcut = ops.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
@@ -189,22 +189,22 @@ class AttnBlock(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
-self.q = comfy.ops.Conv2d(in_channels,
+self.q = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
-self.k = comfy.ops.Conv2d(in_channels,
+self.k = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
-self.v = comfy.ops.Conv2d(in_channels,
+self.v = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
-self.proj_out = comfy.ops.Conv2d(in_channels,
+self.proj_out = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
@@ -244,22 +244,22 @@ class MemoryEfficientAttnBlock(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
-self.q = comfy.ops.Conv2d(in_channels,
+self.q = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
-self.k = comfy.ops.Conv2d(in_channels,
+self.k = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
-self.v = comfy.ops.Conv2d(in_channels,
+self.v = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
-self.proj_out = comfy.ops.Conv2d(in_channels,
+self.proj_out = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
@@ -303,22 +303,22 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
-self.q = comfy.ops.Conv2d(in_channels,
+self.q = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
-self.k = comfy.ops.Conv2d(in_channels,
+self.k = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
-self.v = comfy.ops.Conv2d(in_channels,
+self.v = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
-self.proj_out = comfy.ops.Conv2d(in_channels,
+self.proj_out = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
@@ -400,14 +400,14 @@ class Model(nn.Module):
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
-comfy.ops.Linear(self.ch,
+ops.Linear(self.ch,
self.temb_ch),
-comfy.ops.Linear(self.temb_ch,
+ops.Linear(self.temb_ch,
self.temb_ch),
])
# downsampling
-self.conv_in = comfy.ops.Conv2d(in_channels,
+self.conv_in = ops.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
@@ -476,7 +476,7 @@
# end
self.norm_out = Normalize(block_in)
-self.conv_out = comfy.ops.Conv2d(block_in,
+self.conv_out = ops.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
@@ -549,7 +549,7 @@ class Encoder(nn.Module):
self.in_channels = in_channels
# downsampling
-self.conv_in = comfy.ops.Conv2d(in_channels,
+self.conv_in = ops.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
@@ -594,7 +594,7 @@ class Encoder(nn.Module):
# end
self.norm_out = Normalize(block_in)
-self.conv_out = comfy.ops.Conv2d(block_in,
+self.conv_out = ops.Conv2d(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
@@ -654,7 +654,7 @@ class Decoder(nn.Module):
self.z_shape, np.prod(self.z_shape)))
# z to block_in
-self.conv_in = comfy.ops.Conv2d(z_channels,
+self.conv_in = ops.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
@@ -696,7 +696,7 @@ class Decoder(nn.Module):
# end
self.norm_out = Normalize(block_in)
-self.conv_out = comfy.ops.Conv2d(block_in,
+self.conv_out = ops.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,


@@ -14,8 +14,8 @@ from .util import (
timestep_embedding,
)
from ..attention import SpatialTransformer
-from comfy.ldm.util import exists
-import comfy.ops
+from ...util import exists
+from .... import ops
class TimestepBlock(nn.Module):
"""
@@ -70,7 +70,7 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
-def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
+def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -106,7 +106,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
-def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
+def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -156,7 +156,7 @@ class ResBlock(TimestepBlock):
down=False,
dtype=None,
device=None,
-operations=comfy.ops
+operations=ops
):
super().__init__()
self.channels = channels
@@ -316,7 +316,7 @@ class UNetModel(nn.Module):
adm_in_channels=None,
transformer_depth_middle=None,
device=None,
-operations=comfy.ops,
+operations=ops,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"


@@ -4,7 +4,7 @@ import numpy as np
from functools import partial
from .util import extract_into_tensor, make_beta_schedule
-from comfy.ldm.util import default
+from ...util import default
class AbstractLowScaleModel(nn.Module):


@@ -15,8 +15,8 @@ import torch.nn as nn
import numpy as np
from einops import repeat
-from comfy.ldm.util import instantiate_from_config
-import comfy.ops
+from ...util import instantiate_from_config
+from .... import ops
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
@@ -233,7 +233,7 @@ def conv_nd(dims, *args, **kwargs):
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
-return comfy.ops.Conv2d(*args, **kwargs)
+return ops.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
@@ -243,7 +243,7 @@ def linear(*args, **kwargs):
"""
Create a linear module.
"""
-return comfy.ops.Linear(*args, **kwargs)
+return ops.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
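`conv_nd` and `linear` are the funnel through which layer construction reaches the shared `ops` namespace. Notably, only the 2-D case is routed to `ops.Conv2d`, while 1-D and 3-D fall back to plain torch layers, presumably because these diffusion models only need the customized behavior for 2-D convolutions. A hypothetical call for illustration:

conv = conv_nd(2, 64, 128, 3, padding=1)  # dispatches to ops.Conv2d(64, 128, 3, padding=1)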


@@ -24,7 +24,7 @@ except ImportError:
from torch import Tensor
from typing import List
-from comfy import model_management
+from ... import model_management
def dynamic_slice(
x: Tensor,


@@ -1,8 +1,8 @@
import torch
-from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
-from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
-from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
-from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
+from .ldm.modules.diffusionmodules.openaimodel import UNetModel
+from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
+from .ldm.modules.diffusionmodules.util import make_beta_schedule
+from .ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np
from enum import Enum
from . import utils


@@ -1,6 +1,6 @@
import psutil
from enum import Enum
-from comfy.cli_args import args
+from .cli_args import args
import torch
import sys


@@ -11,19 +11,17 @@ from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
-import comfy.diffusers_load
-import comfy.samplers
-import comfy.sample
-import comfy.sd
-import comfy.utils
+from .. import diffusers_load
+from .. import samplers
+from .. import sample
+from .. import sd
+from .. import utils
+from .. import clip_vision as clip_vision_module
+from .. import model_management
+from ..cli_args import args
-import comfy.clip_vision
-import comfy.model_management
-from comfy.cli_args import args
-from comfy.cmd import folder_paths, latent_preview
-from comfy.nodes.common import MAX_RESOLUTION
+from ..cmd import folder_paths, latent_preview
+from ..nodes.common import MAX_RESOLUTION
class CLIPTextEncode:
@@ -361,7 +359,7 @@ class SaveLatent:
output["latent_tensor"] = samples["samples"]
output["latent_format_version_0"] = torch.tensor([])
-comfy.utils.save_torch_file(output, file, metadata=metadata)
+utils.save_torch_file(output, file, metadata=metadata)
return { "ui": { "latents": results } }
@@ -414,7 +412,7 @@ class CheckpointLoader:
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
-return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CheckpointLoaderSimple:
@classmethod
@@ -428,7 +426,7 @@ class CheckpointLoaderSimple:
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
-out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out
class DiffusersLoader:
@@ -455,7 +453,7 @@ class DiffusersLoader:
model_path = path
break
-return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+return diffusers_load.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class unCLIPCheckpointLoader:
@@ -470,7 +468,7 @@ class unCLIPCheckpointLoader:
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
-out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out
class CLIPSetLastLayer:
@@ -521,10 +519,10 @@ class LoraLoader:
del temp
if lora is None:
-lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
+lora = utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
-model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
+model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
return (model_lora, clip_lora)
class VAELoader:
@@ -539,7 +537,7 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name)
-vae = comfy.sd.VAE(ckpt_path=vae_path)
+vae = sd.VAE(ckpt_path=vae_path)
return (vae,)
class ControlNetLoader:
@@ -554,7 +552,7 @@ class ControlNetLoader:
def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
-controlnet = comfy.sd.load_controlnet(controlnet_path)
+controlnet = sd.load_controlnet(controlnet_path)
return (controlnet,)
class DiffControlNetLoader:
@@ -570,7 +568,7 @@ class DiffControlNetLoader:
def load_controlnet(self, model, control_net_name):
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
-controlnet = comfy.sd.load_controlnet(controlnet_path, model)
+controlnet = sd.load_controlnet(controlnet_path, model)
return (controlnet,)
@@ -663,7 +661,7 @@ class UNETLoader:
def load_unet(self, unet_name):
unet_path = folder_paths.get_full_path("unet", unet_name)
-model = comfy.sd.load_unet(unet_path)
+model = sd.load_unet(unet_path)
return (model,)
class CLIPLoader:
@@ -678,7 +676,7 @@ class CLIPLoader:
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip", clip_name)
-clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
class DualCLIPLoader:
@@ -695,7 +693,7 @@ class DualCLIPLoader:
def load_clip(self, clip_name1, clip_name2):
clip_path1 = folder_paths.get_full_path("clip", clip_name1)
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
-clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
class CLIPVisionLoader:
@@ -710,7 +708,7 @@ class CLIPVisionLoader:
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
-clip_vision = comfy.clip_vision.load(clip_path)
+clip_vision = clip_vision_module.load(clip_path)
return (clip_vision,)
class CLIPVisionEncode:
@@ -740,7 +738,7 @@ class StyleModelLoader:
def load_style_model(self, style_model_name):
style_model_path = folder_paths.get_full_path("style_models", style_model_name)
-style_model = comfy.sd.load_style_model(style_model_path)
+style_model = sd.load_style_model(style_model_path)
return (style_model,)
@@ -805,7 +803,7 @@ class GLIGENLoader:
def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
-gligen = comfy.sd.load_gligen(gligen_path)
+gligen = sd.load_gligen(gligen_path)
return (gligen,)
class GLIGENTextBoxApply:
@@ -933,7 +931,7 @@ class LatentUpscale:
def upscale(self, samples, upscale_method, width, height, crop):
s = samples.copy()
-s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
+s["samples"] = utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
return (s,)
class LatentUpscaleBy:
@@ -952,7 +950,7 @@ class LatentUpscaleBy:
s = samples.copy()
width = round(samples["samples"].shape[3] * scale_by)
height = round(samples["samples"].shape[2] * scale_by)
-s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
+s["samples"] = utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
return (s,)
class LatentRotate:
@@ -1068,7 +1066,7 @@ class LatentBlend:
if samples1.shape != samples2.shape:
samples2.permute(0, 3, 1, 2)
-samples2 = comfy.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
+samples2 = utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
samples2.permute(0, 2, 3, 1)
samples_blended = self.blend_mode(samples1, samples2, blend_mode)
@@ -1133,14 +1131,14 @@ class SetLatentNoiseMask:
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
-device = comfy.model_management.get_torch_device()
+device = model_management.get_torch_device()
latent_image = latent["samples"]
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
batch_inds = latent["batch_index"] if "batch_index" in latent else None
-noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
+noise = sample.prepare_noise(latent_image, seed, batch_inds)
noise_mask = None
if "noise_mask" in latent:
@@ -1152,14 +1150,14 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
previewer = latent_preview.get_previewer(device, model.model.latent_format)
-pbar = comfy.utils.ProgressBar(steps)
+pbar = utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
preview_bytes = None
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)
-samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
+samples = sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed)
out = latent.copy()
@@ -1174,8 +1172,8 @@ class KSampler:
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
-"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
-"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
+"sampler_name": (samplers.KSampler.SAMPLERS, ),
+"scheduler": (samplers.KSampler.SCHEDULERS, ),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"latent_image": ("LATENT", ),
@@ -1200,8 +1198,8 @@ class KSamplerAdvanced:
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
-"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
-"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
+"sampler_name": (samplers.KSampler.SAMPLERS, ),
+"scheduler": (samplers.KSampler.SCHEDULERS, ),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"latent_image": ("LATENT", ),
@@ -1397,7 +1395,7 @@ class ImageScale:
def upscale(self, image, upscale_method, width, height, crop):
samples = image.movedim(-1,1)
-s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
+s = utils.common_upscale(samples, width, height, upscale_method, crop)
s = s.movedim(1,-1)
return (s,)
@@ -1417,7 +1415,7 @@ class ImageScaleBy:
samples = image.movedim(-1,1)
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
-s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
+s = utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1,-1)
return (s,)
@@ -1449,7 +1447,7 @@ class ImageBatch:
def batch(self, image1, image2):
if image1.shape[1:] != image2.shape[1:]:
-image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
+image2 = utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
s = torch.cat((image1, image2), dim=0)
return (s,)


@@ -6,14 +6,14 @@ import pkgutil
import time
import types
-from comfy.nodes import base_nodes as base_nodes
-from comfy_extras import nodes as comfy_extras_nodes
+from . import base_nodes
+from ...comfy_extras import nodes as comfy_extras_nodes
try:
import custom_nodes
except:
custom_nodes = None
-from comfy.nodes.package_typing import ExportedNodes
+from .package_typing import ExportedNodes
from functools import reduce
from pkg_resources import resource_filename
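The triple-dot import above deserves a note: from `comfy/nodes/package.py`, `.` is `comfy.nodes`, `..` is `comfy`, and `...` is comfy's parent, so `from ...comfy_extras import nodes` only resolves if `comfy` and `comfy_extras` both sit inside a common root package. If `comfy` is itself the top-level package, Python raises `ImportError: attempted relative import beyond top-level package`. A sketch of the layout this implies (the root name is hypothetical):

# root_pkg/__init__.py
# root_pkg/comfy/__init__.py
# root_pkg/comfy/nodes/package.py      <- the importing module
# root_pkg/comfy_extras/__init__.py
#
# "."   -> root_pkg.comfy.nodes
# ".."  -> root_pkg.comfy
# "..." -> root_pkg

The same constraint applies to the `from ...comfy import ...` lines in the comfy_extras files later in this diff.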


@@ -1,6 +1,6 @@
import torch
-import comfy.model_management
-import comfy.samplers
+from . import model_management
+from . import samplers
import math
import numpy as np
@@ -71,14 +71,14 @@ def cleanup_additional_models(models):
m.cleanup()
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
-device = comfy.model_management.get_torch_device()
+device = model_management.get_torch_device()
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise.shape, device)
real_model = None
models = get_additional_models(positive, negative)
-comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]))
+model_management.load_models_gpu([model] + models, model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]))
real_model = model.model
noise = noise.to(device)
@@ -88,7 +88,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
negative_copy = broadcast_cond(negative, noise.shape[0], device)
-sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
+sampler = samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu()


@@ -2,11 +2,11 @@ from .k_diffusion import sampling as k_diffusion_sampling
from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc
import torch
-from comfy import model_management
+from . import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
-from comfy import model_base
+from . import model_base
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)


@@ -3,7 +3,7 @@ import contextlib
import copy
import inspect
-from comfy import model_management
+from . import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL
import yaml


@@ -1,7 +1,7 @@
import os
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
-import comfy.ops
+from . import ops
import torch
import traceback
import zipfile
@@ -57,7 +57,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
config = CLIPTextConfig.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
-with comfy.ops.use_comfy_ops():
+with ops.use_comfy_ops():
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)


@@ -1,6 +1,6 @@
from pkg_resources import resource_filename
-from comfy import sd1_clip
+from . import sd1_clip
import torch
import os


@@ -1,4 +1,4 @@
-from comfy import sd1_clip
+from . import sd1_clip
import torch
import os


@@ -3,7 +3,7 @@ import os.path
import torch
import math
import struct
-import comfy.checkpoint_pickle
+from . import checkpoint_pickle
import safetensors.torch
def load_torch_file(ckpt, safe_load=False, device=None):
@@ -19,7 +19,7 @@ def load_torch_file(ckpt, safe_load=False, device=None):
if safe_load:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
-pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
+pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:


@@ -1,5 +1,5 @@
import torch
-from comfy.nodes.common import MAX_RESOLUTION
+from ...comfy.nodes.common import MAX_RESOLUTION
class CLIPTextEncodeSDXLRefiner:


@@ -1,9 +1,9 @@
-import comfy.utils
-from comfy.cmd import folder_paths
+from ...comfy import utils
+from ...comfy.cmd import folder_paths
import torch
def load_hypernetwork_patch(path, strength):
-sd = comfy.utils.load_torch_file(path, safe_load=True)
+sd = utils.load_torch_file(path, safe_load=True)
activation_func = sd.get('activation_func', 'linear')
is_layer_norm = sd.get('is_layer_norm', False)
use_dropout = sd.get('use_dropout', False)


@@ -2,7 +2,7 @@ import numpy as np
from scipy.ndimage import grey_dilation
import torch
-from comfy.nodes.common import MAX_RESOLUTION
+from ...comfy.nodes.common import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):


@@ -1,12 +1,11 @@
-import comfy.sd
-import comfy.utils
-import comfy.model_base
+from ...comfy import sd
+from ...comfy import model_base
-from comfy.cmd import folder_paths
+from ...comfy.cmd import folder_paths
import json
import os
-from comfy.cli_args import args
+from ...comfy.cli_args import args
class ModelMergeSimple:
@classmethod
@@ -106,9 +105,9 @@ class CheckpointSave:
metadata = {}
enable_modelspec = True
-if isinstance(model.model, comfy.model_base.SDXL):
+if isinstance(model.model, model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
-elif isinstance(model.model, comfy.model_base.SDXLRefiner):
+elif isinstance(model.model, model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
@@ -123,9 +122,9 @@
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
-if model.model.model_type == comfy.model_base.ModelType.EPS:
+if model.model.model_type == model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
-elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
+elif model.model.model_type == model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
@@ -137,7 +136,7 @@
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
-comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
+sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {}


@@ -4,7 +4,7 @@ import torch.nn.functional as F
from PIL import Image
import math
-import comfy.utils
+from ...comfy import utils
class Blend:
@@ -35,7 +35,7 @@ class Blend:
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)
-image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
+image2 = utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
image2 = image2.permute(0, 2, 3, 1)
blended_image = self.blend_mode(image1, image2, blend_mode)
@@ -232,7 +232,7 @@ class ImageScaleToTotalPixels:
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
-s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
+s = utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1,-1)
return (s,)


@@ -1,8 +1,8 @@
-from comfy_extras.chainner_models import model_loading
-from comfy import model_management
+from ..chainner_models import model_loading
+from ...comfy import model_management
import torch
-import comfy.utils
-from comfy.cmd import folder_paths
+from ...comfy import utils
+from ...comfy.cmd import folder_paths
class UpscaleModelLoader:
@@ -17,7 +17,7 @@ class UpscaleModelLoader:
def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name)
-sd = comfy.utils.load_torch_file(model_path, safe_load=True)
+sd = utils.load_torch_file(model_path, safe_load=True)
out = model_loading.load_state_dict(sd).eval()
return (out, )
@@ -45,9 +45,9 @@ class ImageUpscaleWithModel:
oom = True
while oom:
try:
-steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
-pbar = comfy.utils.ProgressBar(steps)
-s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
+steps = in_img.shape[0] * utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
+pbar = utils.ProgressBar(steps)
+s = utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
tile //= 2
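The `while oom` loop above is a simple out-of-memory backoff: attempt the tiled upscale, and on `model_management.OOM_EXCEPTION` halve the tile size and try again. A generic sketch of the pattern (the exception type and the floor value are illustrative, not comfy's):

def with_tile_backoff(run, tile=512, min_tile=64, oom_exc=(RuntimeError,)):
    """Retry run(tile) with progressively smaller tiles until it fits."""
    while True:
        try:
            return run(tile)
        except oom_exc:
            tile //= 2
            if tile < min_tile:
                raise  # give up rather than loop forever on tiny tiles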