diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml new file mode 100644 index 000000000..421dd5ee4 --- /dev/null +++ b/.github/workflows/test-build.yml @@ -0,0 +1,31 @@ +name: Build package + +# +# This workflow is a test of the python package build. +# Install Python dependencies across different Python versions. +# + +on: + push: + paths: + - "requirements.txt" + - ".github/workflows/test-build.yml" + +jobs: + build: + name: Build Test + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt \ No newline at end of file diff --git a/.gitignore b/.gitignore index 7e373817d..98d91318d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,18 +1,16 @@ __pycache__/ *.py[cod] -output/ -input/ -!input/example.png -models/ -temp/ -custom_nodes/ +/output/ +/input/ +!/input/example.png +/models/ +/temp/ +/custom_nodes/ !custom_nodes/example_node.py.example extra_model_paths.yaml /.vs .idea/ venv/ -web/extensions/* -!web/extensions/logging.js.example -!web/extensions/core/ -comfyui.log -comfyui.prev.log +/web/extensions/* +!/web/extensions/logging.js.example +!/web/extensions/core/ diff --git a/comfy/cli_args.py b/comfy/cli_args.py index fda245433..ffae81c49 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,6 +1,6 @@ import argparse import enum - +import comfy.options class EnumAction(argparse.Action): """ @@ -94,7 +94,10 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") -args = parser.parse_args() +if comfy.options.args_parsing: + args = parser.parse_args() +else: + args = parser.parse_args([]) if args.windows_standalone_build: args.auto_launch = True diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 9b95ae003..1206c680d 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -56,6 +56,7 @@ class ClipVisionModel(): if t is not None: if k == 'hidden_states': outputs["penultimate_hidden_states"] = t[-2].cpu() + outputs["hidden_states"] = None else: outputs[k] = t.cpu() diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index eb088d92b..937c5a388 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -706,3 +706,34 @@ def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disab noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) + +def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): + alpha_cumprod = 1 / ((sigma * sigma) + 1) + alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1) + alpha = (alpha_cumprod / alpha_cumprod_prev) + + mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt()) + if sigma_prev > 0: + mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. 
- alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) + return mu + + +def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler) + if sigmas[i + 1] != 0: + x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0) + return x + + +@torch.no_grad() +def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): + return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) + diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 8b59cfbdc..fadc0eec7 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -1,5 +1,9 @@ class LatentFormat: + scale_factor = 1.0 + latent_rgb_factors = None + taesd_decoder_name = None + def process_in(self, latent): return latent * self.scale_factor diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index 139c8e01e..befab0075 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -33,7 +33,6 @@ class DDIMSampler(object): assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) - self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) @@ -195,7 +194,7 @@ class DDIMSampler(object): temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False): - device = self.model.betas.device + device = self.model.alphas_cumprod.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) diff --git a/comfy/model_base.py b/comfy/model_base.py index ca154dba0..ed2dc83e4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -181,7 +181,7 @@ class SDXLRefiner(BaseModel): out.append(self.embedder(torch.Tensor([crop_h]))) out.append(self.embedder(torch.Tensor([crop_w]))) out.append(self.embedder(torch.Tensor([aesthetic_score]))) - flat = torch.flatten(torch.cat(out))[None, ] + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) class SDXL(BaseModel): @@ -206,5 +206,5 @@ class SDXL(BaseModel): out.append(self.embedder(torch.Tensor([crop_w]))) out.append(self.embedder(torch.Tensor([target_height]))) out.append(self.embedder(torch.Tensor([target_width]))) - flat = torch.flatten(torch.cat(out))[None, ] + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) diff --git a/comfy/model_management.py b/comfy/model_management.py index 
b663e8f59..1050c13a4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -165,6 +165,9 @@ try: ENABLE_PYTORCH_ATTENTION = True if torch.cuda.is_bf16_supported(): VAE_DTYPE = torch.bfloat16 + if is_intel_xpu(): + if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: + ENABLE_PYTORCH_ATTENTION = True except: pass @@ -451,6 +454,8 @@ def text_encoder_device(): if args.gpu_only: return get_torch_device() elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: + if is_intel_xpu(): + return torch.device("cpu") if should_use_fp16(prioritize_performance=False): return get_torch_device() else: @@ -476,6 +481,23 @@ def get_autocast_device(dev): return dev.type return "cuda" +def cast_to_device(tensor, device, dtype, copy=False): + device_supports_cast = False + if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: + device_supports_cast = True + elif tensor.dtype == torch.bfloat16: + if hasattr(device, 'type') and device.type.startswith("cuda"): + device_supports_cast = True + + if device_supports_cast: + if copy: + if tensor.device == device: + return tensor.to(dtype, copy=copy) + return tensor.to(device, copy=copy).to(dtype) + else: + return tensor.to(device).to(dtype) + else: + return tensor.to(dtype).to(device, copy=copy) def xformers_enabled(): global directml_enabled diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a6ee0bae1..10551656e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -3,6 +3,7 @@ import copy import inspect import comfy.utils +import comfy.model_management class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None): @@ -154,7 +155,7 @@ class ModelPatcher: self.backup[key] = weight.to(self.offload_device) if device_to is not None: - temp_weight = weight.float().to(device_to, copy=True) + temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) else: temp_weight = weight.to(torch.float32, copy=True) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) @@ -185,15 +186,15 @@ class ModelPatcher: if w1.shape != weight.shape: print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: - weight += alpha * w1.type(weight.dtype).to(weight.device) + weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) elif len(v) == 4: #lora/locon - mat1 = v[0].float().to(weight.device) - mat2 = v[1].float().to(weight.device) + mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) + mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) if v[2] is not None: alpha *= v[2] / mat2.shape[0] if v[3] is not None: #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = v[3].float().to(weight.device) + mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32) final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: @@ -212,18 +213,23 @@ class ModelPatcher: if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(w1_a.float(), w1_b.float()) + w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32)) else: - w1 = 
w1.float().to(weight.device) + w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device)) + w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32)) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device)) + w2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32)) else: - w2 = w2.float().to(weight.device) + w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -244,11 +250,20 @@ class ModelPatcher: if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device)) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device)) + m1 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t1, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1a, weight.device, torch.float32)) + + m2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2a, weight.device, torch.float32)) else: - m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device)) - m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device)) + m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1b, weight.device, torch.float32)) + m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2b, weight.device, torch.float32)) try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) diff --git a/comfy/options.py b/comfy/options.py new file mode 100644 index 000000000..f7f8af41e --- /dev/null +++ b/comfy/options.py @@ -0,0 +1,6 @@ + +args_parsing = False + +def enable_args_parsing(enable=True): + global args_parsing + args_parsing = enable diff --git a/comfy/samplers.py b/comfy/samplers.py index c60288fd1..e3192ca58 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -7,6 +7,7 @@ from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps import math from comfy import model_base +import comfy.utils def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) return abs(a*b) // math.gcd(a, b) @@ -255,6 +256,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con else: transformer_options["patches"] = patches + transformer_options["cond_or_uncond"] = cond_or_uncond[:] c['transformer_options'] = transformer_options if 'model_function_wrapper' in model_options: @@ -537,7 +539,7 @@ 
def encode_adm(model, conds, batch_size, width, height, device, prompt_type): if adm_out is not None: x[1] = x[1].copy() - x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device) + x[1]["adm_encoded"] = comfy.utils.repeat_to_batch_size(adm_out, batch_size).to(device) return conds @@ -546,7 +548,7 @@ class KSampler: SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model diff --git a/comfy/sd.py b/comfy/sd.py index 8be0bcbc8..9bdb2ad64 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -454,20 +454,26 @@ def load_unet(unet_path): #load unet in diffusers format sd = comfy.utils.load_torch_file(unet_path) parameters = comfy.utils.calculate_parameters(sd) fp16 = model_management.should_use_fp16(model_params=parameters) + if "input_blocks.0.0.weight" in sd: #ldm + model_config = model_detection.model_config_from_unet(sd, "", fp16) + if model_config is None: + raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) + new_sd = sd - model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) - if model_config is None: - print("ERROR UNSUPPORTED UNET", unet_path) - return None + else: #diffusers + model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) + if model_config is None: + print("ERROR UNSUPPORTED UNET", unet_path) + return None - diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) + diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) - new_sd = {} - for k in diffusers_keys: - if k in sd: - new_sd[diffusers_keys[k]] = sd.pop(k) - else: - print(diffusers_keys[k], k) + new_sd = {} + for k in diffusers_keys: + if k in sd: + new_sd[diffusers_keys[k]] = sd.pop(k) + else: + print(diffusers_keys[k], k) offload_device = model_management.unet_offload_device() model = model_config.get_model(new_sd, "") model = model.to(offload_device) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 477d5c309..9978b6c35 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -60,6 +60,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): if dtype is not None: self.transformer.to(dtype) + self.transformer.text_model.embeddings.token_embedding.to(torch.float32) + self.transformer.text_model.embeddings.position_embedding.to(torch.float32) + self.max_length = max_length if freeze: self.freeze() @@ -68,6 +71,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.empty_tokens = [[49406] + [49407] * 76] self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.enable_attention_masks = False self.layer_norm_hidden_state = True if layer == "hidden": @@ -138,13 +142,23 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = torch.LongTensor(tokens).to(device) - if backup_embeds.weight.dtype != torch.float32: + if 
self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32: precision_scope = torch.autocast else: precision_scope = lambda a, b: contextlib.nullcontext(a) with precision_scope(model_management.get_autocast_device(device), torch.float32): - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + attention_mask = None + if self.enable_attention_masks: + attention_mask = torch.zeros_like(tokens) + max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 + for x in range(attention_mask.shape[0]): + for y in range(attention_mask.shape[1]): + attention_mask[x, y] = 1 + if tokens[x, y] == max_token: + break + + outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden") self.transformer.set_input_embeddings(backup_embeds) if self.layer == "last": diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 818c9711e..05e50a005 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -12,16 +12,6 @@ class SD2ClipModel(sd1_clip.SD1ClipModel): super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) self.empty_tokens = [[49406] + [49407] + [0] * 75] - def clip_layer(self, layer_idx): - if layer_idx < 0: - layer_idx -= 1 #The real last layer of SD2.x clip is the penultimate one. The last one might contain garbage. - if abs(layer_idx) >= 24: - self.layer = "hidden" - self.layer_idx = -2 - else: - self.layer = "hidden" - self.layer_idx = layer_idx - class SD2Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) diff --git a/comfy/utils.py b/comfy/utils.py index 3ed32e372..7843b58cc 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -3,6 +3,8 @@ import math import struct import comfy.checkpoint_pickle import safetensors.torch +import numpy as np +from PIL import Image def load_torch_file(ckpt, safe_load=False, device=None): if device is None: @@ -346,6 +348,13 @@ def bislerp(samples, width, height): result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) return result +def lanczos(samples, width, height): + images = [Image.fromarray(np.clip(255. 
* image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] + images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] + images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] + result = torch.stack(images) + return result + def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] @@ -364,6 +373,8 @@ def common_upscale(samples, width, height, upscale_method, crop): if upscale_method == "bislerp": return bislerp(s, width, height) + elif upscale_method == "lanczos": + return lanczos(s, width, height) else: return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index bce4b3dd0..3d42d7806 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -27,6 +27,44 @@ class ModelMergeSimple: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) +class ModelSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, model1, model2, multiplier): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + for k in kp: + m.add_patches({k: kp[k]}, - multiplier, multiplier) + return (m, ) + +class ModelAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, model1, model2): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + for k in kp: + m.add_patches({k: kp[k]}, 1.0, 1.0) + return (m, ) + + class CLIPMergeSimple: @classmethod def INPUT_TYPES(s): @@ -144,6 +182,8 @@ class CheckpointSave: NODE_CLASS_MAPPINGS = { "ModelMergeSimple": ModelMergeSimple, "ModelMergeBlocks": ModelMergeBlocks, + "ModelMergeSubtract": ModelSubtract, + "ModelMergeAdd": ModelAdd, "CheckpointSave": CheckpointSave, "CLIPMergeSimple": CLIPMergeSimple, } diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 51bdb24fa..3f651e594 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -211,7 +211,7 @@ class Sharpen: return (result,) class ImageScaleToTotalPixels: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index e37808b03..733014f3c 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -54,7 +54,13 @@ class Example: "step": 64, #Slider's step "display": "number" # Cosmetic only: display as "number" or "slider" }), - "float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "display": "number"}), + "float_field": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 10.0, + "step": 0.01, + "round": 0.001, #The value represeting the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. 
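                    # Note: rounding behaviour can also be tuned globally from the UI via the
                    # Comfy.DisableFloatRounding and Comfy.FloatRoundingPrecision settings added in web/scripts/ui.js in this same change.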
+ "display": "number"}), "print_to_screen": (["enable", "disable"],), "string_field": ("STRING", { "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node diff --git a/folder_paths.py b/folder_paths.py index 82aedd43f..4a10c68e7 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,14 +1,13 @@ import os import time -supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) folder_names_and_paths = {} base_path = os.path.dirname(os.path.realpath(__file__)) models_dir = os.path.join(base_path, "models") -folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions) +folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions) folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"]) folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions) diff --git a/latent_preview.py b/latent_preview.py index 30c1d1317..87240a582 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -53,7 +53,9 @@ def get_previewer(device, latent_format): method = args.preview_method if method != LatentPreviewMethod.NoPreviews: # TODO previewer methods - taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) + taesd_decoder_path = None + if latent_format.taesd_decoder_name is not None: + taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) if method == LatentPreviewMethod.Auto: method = LatentPreviewMethod.Latent2RGB @@ -68,7 +70,8 @@ def get_previewer(device, latent_format): print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) if previewer is None: - previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) + if latent_format.latent_rgb_factors is not None: + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) return previewer diff --git a/main.py b/main.py index 9f0f80458..7c5eaee0a 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,6 @@ +import comfy.options +comfy.options.enable_args_parsing() + import os import importlib.util import folder_paths diff --git a/nodes.py b/nodes.py index 77d180526..18d82ea80 100644 --- a/nodes.py +++ b/nodes.py @@ -543,8 +543,8 @@ class LoraLoader: return {"required": { "model": ("MODEL",), "clip": ("CLIP", ), "lora_name": (folder_paths.get_filename_list("loras"), ), - "strength_model": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - "strength_clip": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), }} RETURN_TYPES = ("MODEL", "CLIP") FUNCTION = "load_lora" @@ -889,8 +889,8 @@ class EmptyLatentImage: @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}} RETURN_TYPES = ("LATENT",) FUNCTION = 
"generate" @@ -1217,7 +1217,7 @@ class KSampler: {"model": ("MODEL",), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), "positive": ("CONDITIONING", ), @@ -1243,7 +1243,7 @@ class KSamplerAdvanced: "add_noise": (["enable", "disable"], ), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), "positive": ("CONDITIONING", ), @@ -1423,7 +1423,7 @@ class LoadImageMask: return True class ImageScale: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod @@ -1444,7 +1444,7 @@ class ImageScale: return (s,) class ImageScaleBy: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] @classmethod def INPUT_TYPES(s): diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..b5a68e0f1 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +markers = + inference: mark as inference test (deselect with '-m "not inference"') +testpaths = tests +addopts = -s \ No newline at end of file diff --git a/server.py b/server.py index d04060499..b2e16716b 100644 --- a/server.py +++ b/server.py @@ -132,12 +132,12 @@ class PromptServer(): @routes.get("/extensions") async def get_extensions(request): files = glob.glob(os.path.join( - self.web_root, 'extensions/**/*.js'), recursive=True) + glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True) extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)) for name, dir in nodes.EXTENSION_WEB_DIRS.items(): - files = glob.glob(os.path.join(dir, '**/*.js'), recursive=True) + files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True) extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote( name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files))) diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 000000000..2005fd45b --- /dev/null +++ b/tests/README.md @@ -0,0 +1,29 @@ +# Automated Testing + +## Running tests locally + +Additional requirements for running tests: +``` +pip install pytest +pip install websocket-client==1.6.1 +opencv-python==4.6.0.66 +scikit-image==0.21.0 +``` +Run inference tests: +``` +pytest tests/inference +``` + +## Quality regression test +Compares images in 2 directories to ensure they are the same + +1) Run an inference test to save a directory of "ground truth" images +``` + pytest tests/inference --output_dir tests/inference/baseline +``` +2) Make code edits + +3) Run inference and quality comparison tests +``` +pytest +``` \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/compare/conftest.py 
b/tests/compare/conftest.py new file mode 100644 index 000000000..dd5078c9e --- /dev/null +++ b/tests/compare/conftest.py @@ -0,0 +1,41 @@ +import os +import pytest + +# Command line arguments for pytest +def pytest_addoption(parser): + parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images') + parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test') + parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics') + parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images') + +# This initializes args at the beginning of the test session +@pytest.fixture(scope="session", autouse=True) +def args_pytest(pytestconfig): + args = {} + args['baseline_dir'] = pytestconfig.getoption('baseline_dir') + args['test_dir'] = pytestconfig.getoption('test_dir') + args['metrics_file'] = pytestconfig.getoption('metrics_file') + args['img_output_dir'] = pytestconfig.getoption('img_output_dir') + + # Initialize metrics file + with open(args['metrics_file'], 'a') as f: + # if file is empty, write header + if os.stat(args['metrics_file']).st_size == 0: + f.write("| date | run | file | status | value | \n") + f.write("| --- | --- | --- | --- | --- | \n") + + return args + + +def gather_file_basenames(directory: str): + files = [] + for file in os.listdir(directory): + if file.endswith(".png"): + files.append(file) + return files + +# Creates the list of baseline file names to use as a fixture +def pytest_generate_tests(metafunc): + if "baseline_fname" in metafunc.fixturenames: + baseline_fnames = gather_file_basenames(metafunc.config.getoption("baseline_dir")) + metafunc.parametrize("baseline_fname", baseline_fnames) diff --git a/tests/compare/test_quality.py b/tests/compare/test_quality.py new file mode 100644 index 000000000..92a2d5a8b --- /dev/null +++ b/tests/compare/test_quality.py @@ -0,0 +1,195 @@ +import datetime +import numpy as np +import os +from PIL import Image +import pytest +from pytest import fixture +from typing import Tuple, List + +from cv2 import imread, cvtColor, COLOR_BGR2RGB +from skimage.metrics import structural_similarity as ssim + + +""" +This test suite compares images in 2 directories by file name +The directories are specified by the command line arguments --baseline_dir and --test_dir + +""" +# ssim: Structural Similarity Index +# Returns a tuple of (ssim, diff_image) +def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]: + score, diff = ssim(img0, img1, channel_axis=-1, full=True) + # rescale the difference image to 0-255 range + diff = (diff * 255).astype("uint8") + return score, diff + +# Metrics must return a tuple of (score, diff_image) +METRICS = {"ssim": ssim_score} +METRICS_PASS_THRESHOLD = {"ssim": 0.95} + + +class TestCompareImageMetrics: + @fixture(scope="class") + def test_file_names(self, args_pytest): + test_dir = args_pytest['test_dir'] + fnames = self.gather_file_basenames(test_dir) + yield fnames + del fnames + + @fixture(scope="class", autouse=True) + def teardown(self, args_pytest): + yield + # Runs after all tests are complete + # Aggregate output files into a grid of images + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + img_output_dir = args_pytest['img_output_dir'] + metrics_file = args_pytest['metrics_file'] + + grid_dir = 
os.path.join(img_output_dir, "grid") + os.makedirs(grid_dir, exist_ok=True) + + for metric_dir in METRICS.keys(): + metric_path = os.path.join(img_output_dir, metric_dir) + for file in os.listdir(metric_path): + if file.endswith(".png"): + score = self.lookup_score_from_fname(file, metrics_file) + image_file_list = [] + image_file_list.append([ + os.path.join(baseline_dir, file), + os.path.join(test_dir, file), + os.path.join(metric_path, file) + ]) + # Create grid + image_list = [[Image.open(file) for file in files] for files in image_file_list] + grid = self.image_grid(image_list) + grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}")) + + # Tests run for each baseline file name + @fixture() + def fname(self, baseline_fname): + yield baseline_fname + del baseline_fname + + def test_directories_not_empty(self, args_pytest): + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty" + assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty" + + def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest): + # Check that all files in baseline_dir have a file in test_dir with matching metadata + baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname) + file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names] + file_match = self.find_file_match(baseline_file_path, file_paths) + assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}" + + # For a baseline image file, finds the corresponding file name in test_dir and + # compares the images using the metrics in METRICS + @pytest.mark.parametrize("metric", METRICS.keys()) + def test_pipeline_compare( + self, + args_pytest, + fname, + test_file_names, + metric, + ): + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + metrics_output_file = args_pytest['metrics_file'] + img_output_dir = args_pytest['img_output_dir'] + + baseline_file_path = os.path.join(baseline_dir, fname) + + # Find file match + file_paths = [os.path.join(test_dir, f) for f in test_file_names] + test_file = self.find_file_match(baseline_file_path, file_paths) + + # Run metrics + sample_baseline = self.read_img(baseline_file_path) + sample_secondary = self.read_img(test_file) + + score, metric_img = METRICS[metric](sample_baseline, sample_secondary) + metric_status = score > METRICS_PASS_THRESHOLD[metric] + + # Save metric values + with open(metrics_output_file, 'a') as f: + run_info = os.path.splitext(fname)[0] + metric_status_str = "PASS ✅" if metric_status else "FAIL ❌" + date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n") + + # Save metric image + metric_img_dir = os.path.join(img_output_dir, metric) + os.makedirs(metric_img_dir, exist_ok=True) + output_filename = f'{fname}' + Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename)) + + assert score > METRICS_PASS_THRESHOLD[metric] + + def read_img(self, filename: str) -> np.ndarray: + cvImg = imread(filename) + cvImg = cvtColor(cvImg, COLOR_BGR2RGB) + return cvImg + + def image_grid(self, img_list: list[list[Image.Image]]): + # imgs is a 2D list of images + # Assumes the input images are a rectangular grid of equal sized images + rows = len(img_list) + cols = len(img_list[0]) + + w, h = 
img_list[0][0].size + grid = Image.new('RGB', size=(cols*w, rows*h)) + + for i, row in enumerate(img_list): + for j, img in enumerate(row): + grid.paste(img, box=(j*w, i*h)) + return grid + + def lookup_score_from_fname(self, + fname: str, + metrics_output_file: str + ) -> float: + fname_basestr = os.path.splitext(fname)[0] + with open(metrics_output_file, 'r') as f: + for line in f: + if fname_basestr in line: + score = float(line.split('|')[5]) + return score + raise ValueError(f"Could not find score for {fname} in {metrics_output_file}") + + def gather_file_basenames(self, directory: str): + files = [] + for file in os.listdir(directory): + if file.endswith(".png"): + files.append(file) + return files + + def read_file_prompt(self, fname:str) -> str: + # Read prompt from image file metadata + img = Image.open(fname) + img.load() + return img.info['prompt'] + + def find_file_match(self, baseline_file: str, file_paths: List[str]): + # Find a file in file_paths with matching metadata to baseline_file + baseline_prompt = self.read_file_prompt(baseline_file) + + # Do not match empty prompts + if baseline_prompt is None or baseline_prompt == "": + return None + + # Find file match + # Reorder test_file_names so that the file with matching name is first + # This is an optimization because matching file names are more likely + # to have matching metadata if they were generated with the same script + basename = os.path.basename(baseline_file) + file_path_basenames = [os.path.basename(f) for f in file_paths] + if basename in file_path_basenames: + match_index = file_path_basenames.index(basename) + file_paths.insert(0, file_paths.pop(match_index)) + + for f in file_paths: + test_file_prompt = self.read_file_prompt(f) + if baseline_prompt == test_file_prompt: + return f \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..1a35880af --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,36 @@ +import os +import pytest + +# Command line arguments for pytest +def pytest_addoption(parser): + parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images') + parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. 
(listens on all)") + parser.addoption("--port", type=int, default=8188, help="Set the listen port.") + +# This initializes args at the beginning of the test session +@pytest.fixture(scope="session", autouse=True) +def args_pytest(pytestconfig): + args = {} + args['output_dir'] = pytestconfig.getoption('output_dir') + args['listen'] = pytestconfig.getoption('listen') + args['port'] = pytestconfig.getoption('port') + + os.makedirs(args['output_dir'], exist_ok=True) + + return args + +def pytest_collection_modifyitems(items): + # Modifies items so tests run in the correct order + + LAST_TESTS = ['test_quality'] + + # Move the last items to the end + last_items = [] + for test_name in LAST_TESTS: + for item in items.copy(): + print(item.module.__name__, item) + if item.module.__name__ == test_name: + last_items.append(item) + items.remove(item) + + items.extend(last_items) diff --git a/tests/inference/__init__.py b/tests/inference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/graphs/default_graph_sdxl1_0.json b/tests/inference/graphs/default_graph_sdxl1_0.json new file mode 100644 index 000000000..c06c6829c --- /dev/null +++ b/tests/inference/graphs/default_graph_sdxl1_0.json @@ -0,0 +1,144 @@ +{ + "4": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple" + }, + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage" + }, + "6": { + "inputs": { + "text": "a photo of a cat", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode" + }, + "10": { + "inputs": { + "add_noise": "enable", + "noise_seed": 42, + "steps": 20, + "cfg": 7.5, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 0, + "end_at_step": 32, + "return_with_leftover_noise": "enable", + "model": [ + "4", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "15", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSamplerAdvanced" + }, + "12": { + "inputs": { + "samples": [ + "14", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode" + }, + "13": { + "inputs": { + "filename_prefix": "test_inference", + "images": [ + "12", + 0 + ] + }, + "class_type": "SaveImage" + }, + "14": { + "inputs": { + "add_noise": "disable", + "noise_seed": 42, + "steps": 20, + "cfg": 7.5, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 32, + "end_at_step": 10000, + "return_with_leftover_noise": "disable", + "model": [ + "16", + 0 + ], + "positive": [ + "17", + 0 + ], + "negative": [ + "20", + 0 + ], + "latent_image": [ + "10", + 0 + ] + }, + "class_type": "KSamplerAdvanced" + }, + "15": { + "inputs": { + "conditioning": [ + "6", + 0 + ] + }, + "class_type": "ConditioningZeroOut" + }, + "16": { + "inputs": { + "ckpt_name": "sd_xl_refiner_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple" + }, + "17": { + "inputs": { + "text": "a photo of a cat", + "clip": [ + "16", + 1 + ] + }, + "class_type": "CLIPTextEncode" + }, + "20": { + "inputs": { + "text": "", + "clip": [ + "16", + 1 + ] + }, + "class_type": "CLIPTextEncode" + } + } \ No newline at end of file diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py new file mode 100644 index 000000000..141cc5c7e --- /dev/null +++ b/tests/inference/test_inference.py @@ -0,0 +1,239 @@ +from copy import deepcopy +from io import BytesIO +from urllib import request +import numpy +import os +from PIL import Image +import pytest +from pytest import 
fixture +import time +import torch +from typing import Union +import json +import subprocess +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import urllib.request +import urllib.parse + + +from comfy.samplers import KSampler + +""" +These tests generate and save images through a range of parameters +""" + +class ComfyGraph: + def __init__(self, + graph: dict, + sampler_nodes: list[str], + ): + self.graph = graph + self.sampler_nodes = sampler_nodes + + def set_prompt(self, prompt, negative_prompt=None): + # Sets the prompt for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + prompt_node = self.graph[node]['inputs']['positive'][0] + self.graph[prompt_node]['inputs']['text'] = prompt + if negative_prompt: + negative_prompt_node = self.graph[node]['inputs']['negative'][0] + self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt + + def set_sampler_name(self, sampler_name:str, ): + # sets the sampler name for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + self.graph[node]['inputs']['sampler_name'] = sampler_name + + def set_scheduler(self, scheduler:str): + # sets the sampler name for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + self.graph[node]['inputs']['scheduler'] = scheduler + + def set_filename_prefix(self, prefix:str): + # sets the filename prefix for the save nodes + for node in self.graph: + if self.graph[node]['class_type'] == 'SaveImage': + self.graph[node]['inputs']['filename_prefix'] = prefix + + +class ComfyClient: + # From examples/websockets_api_example.py + + def connect(self, + listen:str = '127.0.0.1', + port:Union[str,int] = 8188, + client_id: str = str(uuid.uuid4()) + ): + self.client_id = client_id + self.server_address = f"{listen}:{port}" + ws = websocket.WebSocket() + ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id)) + self.ws = ws + + def queue_prompt(self, prompt): + p = {"prompt": prompt, "client_id": self.client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + + def get_image(self, filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response: + return response.read() + + def get_history(self, prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: + return json.loads(response.read()) + + def get_images(self, graph, save=True): + prompt = graph + if not save: + # Replace save nodes with preview nodes + prompt_str = json.dumps(prompt) + prompt_str = prompt_str.replace('SaveImage', 'PreviewImage') + prompt = json.loads(prompt_str) + + prompt_id = self.queue_prompt(prompt)['prompt_id'] + output_images = {} + while True: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['node'] is None and data['prompt_id'] == prompt_id: + break #Execution is done + else: + continue #previews are binary data + + history = self.get_history(prompt_id)[prompt_id] + for o in history['outputs']: + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + if 'images' in node_output: 
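                        # fetch each saved image's raw bytes from the server's /view endpoint,
                        # keyed by the id of the node that produced it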
+ images_output = [] + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + images_output.append(image_data) + output_images[node_id] = images_output + + return output_images + +# +# Initialize graphs +# +default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json' +with open(default_graph_file, 'r') as file: + default_graph = json.loads(file.read()) +DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14']) +DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0] + +# +# Loop through these variables +# +comfy_graph_list = [DEFAULT_COMFY_GRAPH] +comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID] +prompt_list = [ + 'a painting of a cat', +] + +sampler_list = KSampler.SAMPLERS +scheduler_list = KSampler.SCHEDULERS + +@pytest.mark.inference +@pytest.mark.parametrize("sampler", sampler_list) +@pytest.mark.parametrize("scheduler", scheduler_list) +@pytest.mark.parametrize("prompt", prompt_list) +class TestInference: + # + # Initialize server and client + # + @fixture(scope="class", autouse=True) + def _server(self, args_pytest): + # Start server + p = subprocess.Popen([ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + ]) + yield + p.kill() + torch.cuda.empty_cache() + + def start_client(self, listen:str, port:int): + # Start client + comfy_client = ComfyClient() + # Connect to server (with retries) + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + comfy_client.connect(listen=listen, port=port) + except ConnectionRefusedError as e: + print(e) + print(f"({i+1}/{n_tries}) Retrying...") + else: + break + return comfy_client + + # + # Client and graph fixtures with server warmup + # + # Returns a "_client_graph", which is client-graph pair corresponding to an initialized server + # The "graph" is the default graph + @fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True) + def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph): + comfy_graph = request.param + + # Start client + comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"]) + + # Warm up pipeline + comfy_client.get_images(graph=comfy_graph.graph, save=False) + + yield comfy_client, comfy_graph + del comfy_client + del comfy_graph + torch.cuda.empty_cache() + + @fixture + def client(self, _client_graph): + client = _client_graph[0] + yield client + + @fixture + def comfy_graph(self, _client_graph): + # avoid mutating the graph + graph = deepcopy(_client_graph[1]) + yield graph + + def test_comfy( + self, + client, + comfy_graph, + sampler, + scheduler, + prompt, + request + ): + test_info = request.node.name + comfy_graph.set_filename_prefix(test_info) + # Settings for comfy graph + comfy_graph.set_sampler_name(sampler) + comfy_graph.set_scheduler(scheduler) + comfy_graph.set_prompt(prompt) + + # Generate + images = client.get_images(comfy_graph.graph) + + assert len(images) != 0, "No images generated" + # assert all images are not blank + for images_output in images.values(): + for image_data in images_output: + pil_image = Image.open(BytesIO(image_data)) + assert numpy.array(pil_image).any() != 0, "Image is blank" + + diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index f9a5b7278..606605f0a 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ 
-142,7 +142,7 @@ app.registerExtension({ const r = origOnNodeCreated ? origOnNodeCreated.apply(this) : undefined; if (this.widgets) { for (const w of this.widgets) { - if (w?.options?.forceInput) { + if (w?.options?.forceInput || w?.options?.defaultInput) { const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}]; convertToInput(this, w, config); } diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 4a21a1b34..f81c83a8a 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -4928,7 +4928,9 @@ LGraphNode.prototype.executeAction = function(action) this.title = o.title; this._bounding.set(o.bounding); this.color = o.color; - this.font = o.font; + if (o.font_size) { + this.font_size = o.font_size; + } }; LGraphGroup.prototype.serialize = function() { @@ -4942,7 +4944,7 @@ LGraphNode.prototype.executeAction = function(action) Math.round(b[3]) ], color: this.color, - font: this.font + font_size: this.font_size }; }; diff --git a/web/scripts/app.js b/web/scripts/app.js index a3661da64..5efe08c00 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -532,7 +532,17 @@ export class ComfyApp { } } this.imageRects.push([x, y, cellWidth, cellHeight]); - ctx.drawImage(img, x, y, cellWidth, cellHeight); + + let wratio = cellWidth/img.width; + let hratio = cellHeight/img.height; + var ratio = Math.min(wratio, hratio); + + let imgHeight = ratio * img.height; + let imgY = row * cellHeight + shiftY + (cellHeight - imgHeight)/2; + let imgWidth = ratio * img.width; + let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2; + + ctx.drawImage(img, imgX, imgY, imgWidth, imgHeight); ctx.filter = "none"; } @@ -671,6 +681,10 @@ export class ComfyApp { */ #addPasteHandler() { document.addEventListener("paste", (e) => { + // ctrl+shift+v is used to paste nodes with connections + // this is handled by litegraph + if(this.shiftDown) return; + let data = (e.clipboardData || window.clipboardData); const items = data.items; @@ -735,9 +749,18 @@ export class ComfyApp { */ #addCopyHandler() { document.addEventListener("copy", (e) => { - // copy - if (this.canvas.selected_nodes) { - this.canvas.copyToClipboard(); + if (e.target.type === "text" || e.target.type === "textarea") { + // Default system copy + return; + } + + // copy nodes and clear clipboard + if (e.target.className === "litegraph" && this.canvas.selected_nodes) { + this.canvas.copyToClipboard(); + e.clipboardData.setData('text', ' '); //clearData doesn't remove images from clipboard + e.preventDefault(); + e.stopImmediatePropagation(); + return false; } }); } @@ -840,24 +863,14 @@ export class ComfyApp { // Ctrl+C Copy if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) { - if (e.shiftKey) { - this.copyToClipboard(true); - block_default = true; - } - // Trigger default onCopy + // Trigger onCopy return true; } // Ctrl+V Paste - if ((e.key === 'v') && (e.metaKey || e.ctrlKey)) { - if (e.shiftKey) { - this.pasteFromClipboard(true); - block_default = true; - } - else { - // Trigger default onPaste - return true; - } + if ((e.key === 'v' || e.key == 'V') && (e.metaKey || e.ctrlKey) && !e.shiftKey) { + // Trigger onPaste + return true; } } @@ -1248,6 +1261,10 @@ export class ComfyApp { if (!config.widget.options) config.widget.options = {}; config.widget.options.forceInput = inputData[1].forceInput; } + if(widgetCreated && inputData[1]?.defaultInput && config?.widget) { + if (!config.widget.options) config.widget.options = {}; + 
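                    // mirror the node definition's defaultInput flag onto the widget options;
                    // widgetInputs.js (patched earlier in this diff) converts such widgets into inputs via the forceInput || defaultInput check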
config.widget.options.defaultInput = inputData[1].defaultInput; + } } for (const o in nodeData["output"]) { @@ -1291,7 +1308,13 @@ export class ComfyApp { let reset_invalid_values = false; if (!graphData) { - graphData = structuredClone(defaultGraph); + if (typeof structuredClone === "undefined") + { + graphData = JSON.parse(JSON.stringify(defaultGraph)); + }else + { + graphData = structuredClone(defaultGraph); + } reset_invalid_values = true; } diff --git a/web/scripts/ui.js b/web/scripts/ui.js index f39939bf3..1e7920167 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -577,6 +577,25 @@ export class ComfyUI { defaultValue: false, }); + this.settings.addSetting({ + id: "Comfy.DisableFloatRounding", + name: "Disable rounding floats (requires page reload).", + type: "boolean", + defaultValue: false, + }); + + this.settings.addSetting({ + id: "Comfy.FloatRoundingPrecision", + name: "Decimal places [0 = auto] (requires page reload).", + type: "slider", + attrs: { + min: 0, + max: 6, + step: 1, + }, + defaultValue: 0, + }); + const fileInput = $el("input", { id: "comfy-file-input", type: "file", diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 30caa6a8c..2b0239374 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,18 +1,23 @@ import { api } from "./api.js" -function getNumberDefaults(inputData, defaultStep) { +function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { let defaultVal = inputData[1]["default"]; - let { min, max, step } = inputData[1]; + let { min, max, step, round} = inputData[1]; if (defaultVal == undefined) defaultVal = 0; if (min == undefined) min = 0; if (max == undefined) max = 2048; if (step == undefined) step = defaultStep; -// precision is the number of decimal places to show. -// by default, display the the smallest number of decimal places such that changes of size step are visible. - let precision = Math.max(-Math.floor(Math.log10(step)),0) -// by default, round the value to those decimal places shown. - let round = Math.round(1000000*Math.pow(0.1,precision))/1000000; + // precision is the number of decimal places to show. + // by default, display the the smallest number of decimal places such that changes of size step are visible. + if (precision == undefined) { + precision = Math.max(-Math.floor(Math.log10(step)),0); + } + + if (enable_rounding && (round == undefined || round === true)) { + // by default, round the value to those decimal places shown. 
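        // worked example: step 0.01 -> precision = 2 -> round = 0.01;
        // the *1000000 / 1000000 round-trip keeps Math.pow(0.1, precision) from carrying float error into the snap value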
+ round = Math.round(1000000*Math.pow(0.1,precision))/1000000; + } return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } }; } @@ -268,15 +273,22 @@ export const ComfyWidgets = { "INT:noise_seed": seedWidget, FLOAT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); - const { val, config } = getNumberDefaults(inputData, 0.5); + let precision = app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision"); + let disable_rounding = app.ui.settings.getSettingValue("Comfy.DisableFloatRounding") + if (precision == 0) precision = undefined; + const { val, config } = getNumberDefaults(inputData, 0.5, precision, !disable_rounding); return { widget: node.addWidget(widgetType, inputName, val, function (v) { - this.value = Math.round(v/config.round)*config.round; + if (config.round) { + this.value = Math.round(v/config.round)*config.round; + } else { + this.value = v; + } }, config) }; }, INT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); - const { val, config } = getNumberDefaults(inputData, 1); + const { val, config } = getNumberDefaults(inputData, 1, 0, true); Object.assign(config, { precision: 0 }); return { widget: node.addWidget(