Merge branch 'master' of github.com:comfyanonymous/ComfyUI

commit 39c6335331
doctorpangloss 2024-08-05 16:13:20 -07:00
20 changed files with 228 additions and 64 deletions


@@ -157,7 +157,11 @@ class PromptServer(ExecutorToClientProgress):
         @routes.get("/")
         async def get_root(request):
-            return web.FileResponse(os.path.join(self.web_root, "index.html"))
+            response = web.FileResponse(os.path.join(self.web_root, "index.html"))
+            response.headers['Cache-Control'] = 'no-cache'
+            response.headers["Pragma"] = "no-cache"
+            response.headers["Expires"] = "0"
+            return response
 
         @routes.get("/embeddings")
         def get_embeddings(self):


@@ -9,6 +9,7 @@ import torch.nn.functional as F
 
 from ..modules.attention import optimized_attention
 from ... import ops
+from .. import common_dit
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -407,10 +408,7 @@ class MMDiT(nn.Module):
     def patchify(self, x):
        B, C, H, W = x.size()
 
-        pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
-        pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
-        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
+        x = common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        x = x.view(
            B,
            C,

comfy/ldm/common_dit.py Normal file

@@ -0,0 +1,8 @@
+import torch
+
+def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
+    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
+        padding_mode = "reflect"
+    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
+    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
+    return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
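A quick usage sketch of the new helper (shapes illustrative, not from the commit; import path per the new file above):

import torch
from comfy.ldm.common_dit import pad_to_patch_size

x = torch.zeros(1, 16, 31, 31)         # spatial dims not divisible by the 2x2 patch
padded = pad_to_patch_size(x, (2, 2))  # appends one circularly wrapped row and column
print(padded.shape)                    # torch.Size([1, 16, 32, 32])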


@@ -15,7 +15,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
 
 def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     assert dim % 2 == 0
-    if model_management.is_device_mps(pos.device):
+    if model_management.is_device_mps(pos.device) or model_management.is_intel_xpu():
         device = torch.device("cpu")
     else:
         device = pos.device


@@ -15,6 +15,7 @@ from .layers import (
 )
 
 from einops import rearrange, repeat
+import comfy.ldm.common_dit
 
 @dataclass
 class FluxParams:
@@ -42,7 +43,7 @@ class Flux(nn.Module):
         self.dtype = dtype
         params = FluxParams(**kwargs)
         self.params = params
-        self.in_channels = params.in_channels
+        self.in_channels = params.in_channels * 2 * 2
         self.out_channels = self.in_channels
         if params.hidden_size % params.num_heads != 0:
             raise ValueError(
@@ -125,10 +126,7 @@ class Flux(nn.Module):
     def forward(self, x, timestep, context, y, guidance, **kwargs):
         bs, c, h, w = x.shape
         patch_size = 2
-        pad_h = (patch_size - h % 2) % patch_size
-        pad_w = (patch_size - w % 2) % patch_size
-        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
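A note on the in_channels change: the Flux VAE latent has 16 channels, and forward() flattens each 2x2 latent patch into one token, so the transformer's input layer sees 16 * 2 * 2 = 64 features per token; the raw channel count (16) now lives in the detected config instead (see the model_detection hunk below). A shape check of the rearrange above, assuming a 16-channel latent:

import torch
from einops import rearrange

x = torch.zeros(1, 16, 64, 64)  # bs, c, h, w after padding
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(img.shape)                # torch.Size([1, 1024, 64]); 64 = 16 * 2 * 2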


@@ -10,6 +10,7 @@ from .. import attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
 from .... import ops
+from ... import common_dit
 
 def default(x, y):
     if x is not None:
@@ -112,9 +113,7 @@ class PatchEmbed(nn.Module):
             #     f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
             # )
         if self.dynamic_img_pad:
-            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
-            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
-            x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode)
+            x = common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC


@@ -286,4 +286,12 @@ def model_lora_keys_unet(model, key_map={}):
                 key_lora = k[len("diffusion_model."):-len(".weight")]
                 key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format
 
+    if isinstance(model, model_base.Flux): #Diffusers lora Flux
+        diffusers_keys = utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+        for k in diffusers_keys:
+            if k.endswith(".weight"):
+                to = diffusers_keys[k]
+                key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
+                key_map[key_lora] = to
+
     return key_map
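For illustration, one mapping this loop produces (a minimal sketch; the comfy.utils import path is assumed):

from comfy import utils

diffusers_keys = utils.flux_to_diffusers({"depth": 1, "hidden_size": 3072},
                                         output_prefix="diffusion_model.")
k = "transformer_blocks.0.attn.to_q.weight"
print("transformer.{}".format(k[:-len(".weight")]), "->", diffusers_keys[k])
# transformer.transformer_blocks.0.attn.to_q ->
#     ('diffusion_model.double_blocks.0.img_attn.qkv.weight', (0, 0, 3072))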


@@ -87,6 +87,7 @@ class BaseModel(torch.nn.Module):
             # todo: ???
             self.diffusion_model.to(memory_format=torch.channels_last)
             logging.debug("using channels last mode for diffusion model")
+        logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
@@ -95,6 +96,9 @@ class BaseModel(torch.nn.Module):
         self.adm_channels = 0
 
         self.concat_keys = ()
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
+        self.memory_usage_factor = model_config.memory_usage_factor
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         sigma = t
@@ -256,11 +260,11 @@ class BaseModel(torch.nn.Module):
                 dtype = self.manual_cast_dtype
             # TODO: this needs to be tweaked
             area = input_shape[0] * math.prod(input_shape[2:])
-            return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
+            return (area * model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
         else:
             # TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
             area = input_shape[0] * math.prod(input_shape[2:])
-            return (area * 0.3) * (1024 * 1024)
+            return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
 
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
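Back-of-envelope for the new formula, assuming the xformers/pytorch-attention branch, fp16 weights, and SDXL's memory_usage_factor of 0.7 from the supported_models hunks below (numbers illustrative):

area = 1 * (128 * 128)       # input_shape[0] * math.prod(input_shape[2:]) for one 1024x1024 image
dtype_size = 2               # bytes per fp16 element
memory_usage_factor = 0.7    # per-model tuning factor, SDXL
mib = area * dtype_size * 0.01 * memory_usage_factor
print(round(mib, 1), "MiB")  # 229.4 MiB estimated for this batch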
@@ -614,17 +618,6 @@ class SD3(BaseModel):
         out['c_crossattn'] = conds.CONDRegular(cross_attn)
         return out
 
-    def memory_required(self, input_shape):
-        if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
-            dtype = self.get_dtype()
-            if self.manual_cast_dtype is not None:
-                dtype = self.manual_cast_dtype
-            # TODO: this probably needs to be tweaked
-            area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (area * model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
-        else:
-            area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (area * 0.3) * (1024 * 1024)
 
 class AuraFlow(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@@ -722,15 +715,3 @@ class Flux(BaseModel):
         out['c_crossattn'] = conds.CONDRegular(cross_attn)
         out['guidance'] = conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
         return out
-
-    def memory_required(self, input_shape):
-        if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
-            dtype = self.get_dtype()
-            if self.manual_cast_dtype is not None:
-                dtype = self.manual_cast_dtype
-            #TODO: this probably needs to be tweaked
-            area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (area * model_management.dtype_size(dtype) * 0.020) * (1024 * 1024)
-        else:
-            area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (area * 0.3) * (1024 * 1024)


@@ -130,7 +130,7 @@ def detect_unet_config(state_dict, key_prefix):
     if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
         dit_config = {}
         dit_config["image_model"] = "flux"
-        dit_config["in_channels"] = 64
+        dit_config["in_channels"] = 16
         dit_config["vec_in_dim"] = 768
         dit_config["context_in_dim"] = 4096
         dit_config["hidden_size"] = 3072


@@ -511,7 +511,7 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)))
+            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
             if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
                 lowvram_model_memory = 0
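Rough intent of the new expression, as I read it: the model is always offered at least 40% of currently free VRAM (capped at free minus the inference reserve), rather than only whatever the caller's reservation leaves over. With hypothetical numbers:

GiB = 1024 ** 3
current_free_mem = 8 * GiB
minimum_memory_required = 6 * GiB   # caller's reservation
inference_reserve = 1 * GiB         # assumed minimum_inference_memory()
lowvram_model_memory = max(64 * (1024 * 1024),
                           current_free_mem - minimum_memory_required,   # 2 GiB
                           min(current_free_mem * 0.4,                   # 3.2 GiB
                               current_free_mem - inference_reserve))    # 7 GiB
print(lowvram_model_memory / GiB)   # 3.2 -> the 40% floor wins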
@@ -602,6 +602,9 @@ def unet_initial_load_device(parameters, dtype):
     return cpu_dev
 
+def maximum_vram_for_weights(device=None):
+    return (get_total_memory(device) * 0.88 - minimum_inference_memory())
+
 def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
     if args.bf16_unet:
         return torch.bfloat16
@@ -611,6 +614,21 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
+
+    fp8_dtype = None
+    try:
+        for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            if dtype in supported_dtypes:
+                fp8_dtype = dtype
+                break
+    except:
+        pass
+
+    if fp8_dtype is not None:
+        free_model_memory = maximum_vram_for_weights(device)
+        if model_params * 2 > free_model_memory:
+            return fp8_dtype
+
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         if torch.float16 in supported_dtypes:
             return torch.float16
@@ -973,7 +991,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
             fp16_works = True
 
     if fp16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
@@ -1016,21 +1034,14 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_intel_xpu():
         return True
 
-    if device is None:
-        device = torch.device("cuda")
-    try:
-        props = torch.cuda.get_device_properties(device)
-        if props.major >= 8:
-            return True
-    except AssertionError:
-        logging.warning("Torch was not compiled with CUDA support")
-        return False
+    props = torch.cuda.get_device_properties("cuda")
+    if props.major >= 8:
+        return True
 
     bf16_works = torch.cuda.is_bf16_supported()
 
     if bf16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
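Net effect of the unet_dtype change, sketched with assumed numbers (Flux-dev-scale parameter count, a 24 GiB card, and an assumed 1 GiB minimum_inference_memory()):

GiB = 1024 ** 3
model_params = 12e9
free_model_memory = 24 * GiB * 0.88 - 1 * GiB   # maximum_vram_for_weights(device)
print(model_params * 2 > free_model_memory)     # True: fp16 weights will not fit,
                                                # so float8_e4m3fn is returned instead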


@@ -174,7 +174,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
     for i in range(1, len(to_batch_temp) + 1):
         batch_amount = to_batch_temp[:len(to_batch_temp)//i]
         input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
-        if model.memory_required(input_shape) < free_memory:
+        if model.memory_required(input_shape) * 1.5 < free_memory:
             to_batch = batch_amount
             break


@@ -535,13 +535,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     parameters = utils.calculate_parameters(sd, diffusion_model_prefix)
+    weight_dtype = utils.weight_dtype(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()
 
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
 
-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+    unet_weight_dtype = list(model_config.supported_inference_dtypes)
+    if weight_dtype is not None:
+        unet_weight_dtype.append(weight_dtype)
+
+    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
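The practical effect of threading weight_dtype through here (my reading): a checkpoint saved with, say, fp8 tensors can now stay in fp8 even though the model config only advertises fp16/bf16/fp32. A toy check of the new helper (import path assumed):

import torch
from comfy import utils

sd = {"model.diffusion_model.a.weight": torch.zeros(4, dtype=torch.float8_e4m3fn),
      "model.diffusion_model.b.weight": torch.zeros(4, dtype=torch.float8_e4m3fn),
      "model.diffusion_model.b.bias": torch.zeros(4, dtype=torch.float32)}
print(utils.weight_dtype(sd, "model.diffusion_model."))  # torch.float8_e4m3fn, the most common dtype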


@@ -31,6 +31,7 @@ class SD15(supported_models_base.BASE):
     }
 
     latent_format = latent_formats.SD15
+    memory_usage_factor = 1.0
 
     def process_clip_state_dict(self, state_dict):
         k = list(state_dict.keys())
@@ -77,6 +78,7 @@ class SD20(supported_models_base.BASE):
     }
 
     latent_format = latent_formats.SD15
+    memory_usage_factor = 1.0
 
     def model_type(self, state_dict, prefix=""):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
@@ -140,6 +142,7 @@ class SDXLRefiner(supported_models_base.BASE):
     }
 
     latent_format = latent_formats.SDXL
+    memory_usage_factor = 1.0
 
     def get_model(self, state_dict, prefix="", device=None):
         return model_base.SDXLRefiner(self, device=device)
@@ -178,6 +181,8 @@ class SDXL(supported_models_base.BASE):
     latent_format = latent_formats.SDXL
 
+    memory_usage_factor = 0.7
+
     def model_type(self, state_dict, prefix=""):
         if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
             self.latent_format = latent_formats.SDXL_Playground_2_5()
@@ -505,6 +510,9 @@ class SD3(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.SD3
+
+    memory_usage_factor = 1.2
+
     text_encoder_key_prefix = ["text_encoders."]
 
     def get_model(self, state_dict, prefix="", device=None):
@@ -631,6 +639,9 @@ class Flux(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.Flux
+
+    memory_usage_factor = 2.6
+
     supported_inference_dtypes = [torch.bfloat16, torch.float32]
 
     vae_key_prefix = ["vae."]
@@ -641,7 +652,13 @@ class Flux(supported_models_base.BASE):
         return out
 
     def clip_target(self, state_dict={}):
-        return supported_models_base.ClipTarget(flux.FluxTokenizer, flux.FluxClipModel)
+        pref = self.text_encoder_key_prefix[0]
+        t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
+        if t5_key in state_dict:
+            dtype_t5 = state_dict[t5_key].dtype
+        else:
+            dtype_t5 = None
+        return supported_models_base.ClipTarget(flux.FluxTokenizer, flux.flux_clip(dtype_t5=dtype_t5))
 
 class FluxSchnell(Flux):
     unet_config = {


@@ -27,6 +27,8 @@ class BASE:
     text_encoder_key_prefix = ["cond_stage_model."]
     supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
 
+    memory_usage_factor = 2.0
+
     manual_cast_dtype = None
 
     @classmethod


@@ -56,9 +56,9 @@ class FluxClipModel(torch.nn.Module):
 
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_l = token_weight_pairs["l"]
-        token_weight_pars_t5 = token_weight_pairs["t5xxl"]
+        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
 
-        t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
+        t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
         return t5_out, l_pooled


@@ -87,7 +87,7 @@ class SD3ClipModel(torch.nn.Module):
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_l = token_weight_pairs["l"]
         token_weight_pairs_g = token_weight_pairs["g"]
-        token_weight_pars_t5 = token_weight_pairs["t5xxl"]
+        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
         lg_out = None
         pooled = None
         out = None
@@ -114,7 +114,7 @@ class SD3ClipModel(torch.nn.Module):
             pooled = torch.cat((l_pooled, g_pooled), dim=-1)
 
         if self.t5xxl is not None:
-            t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
+            t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
             if lg_out is not None:
                 out = torch.cat([lg_out, t5_out], dim=-2)
             else:


@@ -74,9 +74,21 @@ def calculate_parameters(sd, prefix=""):
     params = 0
     for k in sd.keys():
         if k.startswith(prefix):
-            params += sd[k].nelement()
+            w = sd[k]
+            params += w.nelement()
     return params
 
+def weight_dtype(sd, prefix=""):
+    dtypes = {}
+    for k in sd.keys():
+        if k.startswith(prefix):
+            w = sd[k]
+            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
+
+    if len(dtypes) == 0:
+        return None
+
+    return max(dtypes, key=dtypes.get)
+
 def state_dict_key_replace(state_dict, keys_to_replace):
     for x in keys_to_replace:
@@ -443,6 +455,59 @@ def auraflow_to_diffusers(mmdit_config, output_prefix=""):
     return key_map
 
+def flux_to_diffusers(mmdit_config, output_prefix=""):
+    n_double_layers = mmdit_config.get("depth", 0)
+    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
+    hidden_size = mmdit_config.get("hidden_size", 0)
+
+    key_map = {}
+    for index in range(n_double_layers):
+        prefix_from = "transformer_blocks.{}".format(index)
+        prefix_to = "{}double_blocks.{}".format(output_prefix, index)
+
+        for end in ("weight", "bias"):
+            k = "{}.attn.".format(prefix_from)
+            qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
+            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
+            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
+            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
+
+        block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
+                     "attn.to_out.0.bias": "img_attn.proj.bias",
+                     }
+
+        for k in block_map:
+            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
+
+    for index in range(n_single_layers):
+        prefix_from = "single_transformer_blocks.{}".format(index)
+        prefix_to = "{}single_blocks.{}".format(output_prefix, index)
+
+        for end in ("weight", "bias"):
+            k = "{}.attn.".format(prefix_from)
+            qkv = "{}.linear1.{}".format(prefix_to, end)
+            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
+            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
+            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
+            key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))
+
+        block_map = {#TODO
+                     }
+
+        for k in block_map:
+            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
+
+    MAP_BASIC = { #TODO
+    }
+
+    for k in MAP_BASIC:
+        if len(k) > 2:
+            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
+        else:
+            key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+
+    return key_map
+
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
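As I read the loader, a tuple value (target, (dim, offset, size)) means the diffusers tensor occupies the slice [offset, offset + size) along dim of the fused native tensor. A quick check with a hypothetical one-block config (import path assumed):

from comfy import utils

keys = utils.flux_to_diffusers({"depth": 1, "depth_single_blocks": 1, "hidden_size": 3072})
print(keys["single_transformer_blocks.0.attn.to_v.weight"])
# ('single_blocks.0.linear1.weight', (0, 6144, 3072)) -> rows 6144..9215 of linear1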


@@ -3,6 +3,9 @@ import comfy.model_sampling
 import comfy.latent_formats
 import torch
 
+from comfy.nodes.common import MAX_RESOLUTION
+
+
 class LCM(comfy.model_sampling.EPS):
     def timestep(self, *args, **kwargs) -> torch.Tensor:
         pass
@@ -173,6 +176,42 @@ class ModelSamplingAuraFlow(ModelSamplingSD3):
     def patch_aura(self, model, shift):
         return self.patch(model, shift, multiplier=1.0)
 
+class ModelSamplingFlux:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL",),
+                             "max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step": 0.01}),
+                             "base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step": 0.01}),
+                             "width": ("INT", {"default": 1024, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+                             "height": ("INT", {"default": 1024, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+                             }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, max_shift, base_shift, width, height):
+        m = model.clone()
+
+        x1 = 256
+        x2 = 4096
+        mm = (max_shift - base_shift) / (x2 - x1)
+        b = base_shift - mm * x1
+        shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
+
+        sampling_base = comfy.model_sampling.ModelSamplingFlux
+        sampling_type = comfy.model_sampling.CONST
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(shift=shift)
+        m.add_object_patch("model_sampling", model_sampling)
+        return (m, )
+
 class ModelSamplingContinuousEDM:
     @classmethod
     def INPUT_TYPES(s):
@@ -289,5 +328,6 @@ NODE_CLASS_MAPPINGS = {
     "ModelSamplingStableCascade": ModelSamplingStableCascade,
     "ModelSamplingSD3": ModelSamplingSD3,
     "ModelSamplingAuraFlow": ModelSamplingAuraFlow,
+    "ModelSamplingFlux": ModelSamplingFlux,
     "RescaleCFG": RescaleCFG,
 }
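The shift in ModelSamplingFlux.patch is a linear interpolation in token count: width * height / (8 * 8 * 2 * 2) is the number of image tokens after 8x VAE downscaling and 2x2 patchification, and shift runs from base_shift at 256 tokens to max_shift at 4096 tokens. With the defaults, a 1024x1024 image lands exactly on max_shift:

max_shift, base_shift = 1.15, 0.5
x1, x2 = 256, 4096
mm = (max_shift - base_shift) / (x2 - x1)   # slope per token
b = base_shift - mm * x1
tokens = 1024 * 1024 / (8 * 8 * 2 * 2)      # 4096.0 tokens
print(tokens * mm + b)                      # ~1.15, i.e. shift == max_shift here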


@@ -1,4 +1,6 @@
 from . import nodes_model_merging
+from .nodes_model_merging import ModelMergeBlocks
+
 
 class ModelMergeSD1(nodes_model_merging.ModelMergeBlocks):
     CATEGORY = "advanced/model_merging/model_specific"
@@ -75,9 +77,36 @@ class ModelMergeSD3_2B(nodes_model_merging.ModelMergeBlocks):
         return {"required": arg_dict}
 
+class ModelMergeFlux1(ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = {"model1": ("MODEL",),
+                    "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["img_in."] = argument
+        arg_dict["time_in."] = argument
+        arg_dict["guidance_in"] = argument
+        arg_dict["vector_in."] = argument
+        arg_dict["txt_in."] = argument
+
+        for i in range(19):
+            arg_dict["double_blocks.{}.".format(i)] = argument
+
+        for i in range(38):
+            arg_dict["single_blocks.{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSD1": ModelMergeSD1,
     "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
     "ModelMergeSDXL": ModelMergeSDXL,
     "ModelMergeSD3_2B": ModelMergeSD3_2B,
+    "ModelMergeFlux1": ModelMergeFlux1,
 }


@@ -33,10 +33,9 @@ class EmptySD3LatentImage:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                             "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
+        return {"required": {"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"
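The step change follows from the SD3/Flux geometry above: each latent pixel covers 8 image pixels and the transformer patchifies 2x2 latents, so image dimensions that are multiples of 8 * 2 = 16 need no padding. A one-line sanity check:

print(1024 % 16 == 0, 1024 // 8 // 2)  # True, 64 tokens per side at 1024px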