diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py
index 2d7af8e26..37a187024 100644
--- a/comfy/cmd/folder_paths.py
+++ b/comfy/cmd/folder_paths.py
@@ -362,11 +362,11 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
     full_output_folder = str(os.path.join(output_dir, subfolder))
 
-    if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
-        err = "**** ERROR: Saving image outside the output folder is not allowed." + \
-              "\n full_output_folder: " + os.path.abspath(full_output_folder) + \
-              "\n         output_dir: " + output_dir + \
-              "\n         commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
+    if str(os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))) != str(output_dir):
+        err = f"""**** ERROR: Saving image outside the output folder is not allowed.
+            full_output_folder: {os.path.abspath(full_output_folder)}
+                    output_dir: {output_dir}
+                    commonpath: {os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))}"""
         logging.error(err)
         raise Exception(err)
diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index 20b968f0a..a99252ba7 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -786,10 +786,10 @@ class PromptServer(ExecutorToClientProgress):
             msg = await self.messages.get()
             await self.send(*msg)
 
-    async def start(self, address, port, verbose=True, call_on_start=None):
+    async def start(self, address: str | None, port: int | None, verbose=True, call_on_start=None):
        runner = web.AppRunner(self.app, access_log=None)
         await runner.setup()
-        site = web.TCPSite(runner, address, port)
+        site = web.TCPSite(runner, host=address, port=port)
         await site.start()
 
         if verbose:
diff --git a/comfy/component_model/file_output_path.py b/comfy/component_model/file_output_path.py
index 0b35f3379..ed2182706 100644
--- a/comfy/component_model/file_output_path.py
+++ b/comfy/component_model/file_output_path.py
@@ -30,13 +30,13 @@ def file_output_path(filename: str, type: Literal["input", "output", "temp"] = "
     if output_dir is None:
         raise ValueError(f"no such output directory because invalid type specified (type={type})")
     if subfolder is not None and subfolder != "":
-        full_output_dir = os.path.join(output_dir, subfolder)
-        if os.path.commonpath([os.path.abspath(full_output_dir), output_dir]) != output_dir:
+        full_output_dir = str(os.path.join(output_dir, subfolder))
+        if str(os.path.commonpath([os.path.abspath(full_output_dir), output_dir])) != str(output_dir):
             raise PermissionError("insecure")
         output_dir = full_output_dir
         filename = os.path.basename(filename)
     else:
-        if os.path.commonpath([os.path.abspath(output_dir), os.path.join(output_dir, filename)]) != output_dir:
+        if str(os.path.commonpath([os.path.abspath(output_dir), os.path.join(output_dir, filename)])) != str(output_dir):
             raise PermissionError("insecure")
 
     file = os.path.join(output_dir, filename)
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 4ca466d9a..69192bc62 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -2,6 +2,7 @@ import torch
 
 class LatentFormat:
     scale_factor = 1.0
+    latent_channels = 4
     latent_rgb_factors = None
     taesd_decoder_name = None
 
@@ -72,6 +73,7 @@ class SD_X4(LatentFormat):
         ]
 
 class SC_Prior(LatentFormat):
+    latent_channels = 16
     def __init__(self):
         self.scale_factor = 1.0
         self.latent_rgb_factors = [
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 96dc027e8..e2381b90b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -6,7 +6,7 @@ import platform
 import warnings
 from enum import Enum
 from threading import RLock
-from typing import Literal
+from typing import Literal, List
 
 import psutil
 import torch
@@ -300,7 +300,7 @@ except:
 
 logging.info("VAE dtype: {}".format(VAE_DTYPE))
 
-current_loaded_models = []
+current_loaded_models: List["LoadedModel"] = []
 
 
 def module_size(module):
@@ -318,6 +318,7 @@ class LoadedModel:
         self.device = model.load_device
         self.weights_loaded = False
         self.real_model = None
+        self.currently_used = True
 
     def model_memory(self):
         return self.model.model_size()
@@ -412,6 +413,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
         if shift_model.device == device:
             if shift_model not in keep_loaded:
                 can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+                shift_model.currently_used = False
 
     for x in sorted(can_unload):
         i = x[-1]
@@ -458,6 +460,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
                 current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
                 loaded = None
             else:
+                loaded.currently_used = True
                 models_already_loaded.append(loaded)
         if loaded is None:
             if hasattr(x, "model"):
@@ -515,6 +518,16 @@ def load_model_gpu(model):
     with model_management_lock:
         return load_models_gpu([model])
 
+def loaded_models(only_currently_used=False):
+    with model_management_lock:
+        output = []
+        for m in current_loaded_models:
+            if only_currently_used:
+                if not m.currently_used:
+                    continue
+
+            output.append(m.model)
+        return output
 
 def cleanup_models(keep_clone_weights_loaded=False):
     with model_management_lock:
@@ -763,6 +776,8 @@ def pytorch_attention_flash_attention():
         # TODO: more reliable way of checking for flash attention?
         if is_nvidia(): # pytorch flash attention only works on Nvidia
             return True
+        if is_intel_xpu():
+            return True
     return False
 
 def force_upcast_attention_dtype():
diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py
index b9a84dad3..11c1114f2 100644
--- a/comfy/nodes/base_nodes.py
+++ b/comfy/nodes/base_nodes.py
@@ -487,7 +487,7 @@ class CheckpointLoader:
 
     CATEGORY = "advanced/loaders"
 
-    def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
+    def load_checkpoint(self, config_name, ckpt_name):
         config_path = folder_paths.get_full_path("configs", config_name)
         ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
         return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
@@ -502,7 +502,7 @@ class CheckpointLoaderSimple:
 
     CATEGORY = "loaders"
 
-    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
+    def load_checkpoint(self, ckpt_name):
         ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
         out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out[:3]
@@ -1300,6 +1300,8 @@ class SetLatentNoiseMask:
 
 def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
     latent_image = latent["samples"]
+    latent_image = sample.fix_empty_latent_channels(model, latent_image)
+
     if disable_noise:
         noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
     else:
diff --git a/comfy/sample.py b/comfy/sample.py
index 7a97e54be..e021385d8 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -26,6 +26,12 @@ def prepare_noise(latent_image, seed, noise_inds=None):
         noises = torch.cat(noises, axis=0)
     return noises
 
+def fix_empty_latent_channels(model, latent_image):
+    latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels
+    if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
+        latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1)
+    return latent_image
+
 def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
     logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
     return model, positive, negative, noise_mask, []
diff --git a/comfy/samplers.py b/comfy/samplers.py
index f38407862..9d5f34547 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -11,7 +11,8 @@ from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
 
 
 def get_area_and_mult(conds, x_in, timestep_in):
-    area = (x_in.shape[2], x_in.shape[3], 0, 0)
+    dims = tuple(x_in.shape[2:])
+    area = None
     strength = 1.0
 
     if 'timestep_start' in conds:
@@ -23,11 +24,16 @@
         if timestep_in[0] < timestep_end:
             return None
     if 'area' in conds:
-        area = conds['area']
+        area = list(conds['area'])
     if 'strength' in conds:
         strength = conds['strength']
 
-    input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+    input_x = x_in
+    if area is not None:
+        for i in range(len(dims)):
+            area[i] = min(input_x.shape[i + 2] - area[len(dims) + i], area[i])
+            input_x = input_x.narrow(i + 2, area[len(dims) + i], area[i])
+
     if 'mask' in conds:
         # Scale the mask to the size of the input
         # The mask should have been resized as we began the sampling process
@@ -35,28 +41,30 @@
         if "mask_strength" in conds:
             mask_strength = conds["mask_strength"]
         mask = conds['mask']
-        assert(mask.shape[1] == x_in.shape[2])
-        assert(mask.shape[2] == x_in.shape[3])
-        mask = mask[:input_x.shape[0],area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
+        assert(mask.shape[1:] == x_in.shape[2:])
+
+        mask = mask[:input_x.shape[0]]
+        if area is not None:
+            for i in range(len(dims)):
+                mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
+
+        mask = mask * mask_strength
         mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
     else:
         mask = torch.ones_like(input_x)
     mult = mask * strength
 
-    if 'mask' not in conds:
+    if 'mask' not in conds and area is not None:
         rr = 8
-        if area[2] != 0:
-            for t in range(rr):
-                mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
-        if (area[0] + area[2]) < x_in.shape[2]:
-            for t in range(rr):
-                mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
-        if area[3] != 0:
-            for t in range(rr):
-                mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
-        if (area[1] + area[3]) < x_in.shape[3]:
-            for t in range(rr):
-                mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
+        for i in range(len(dims)):
+            if area[len(dims) + i] != 0:
+                for t in range(rr):
+                    m = mult.narrow(i + 2, t, 1)
+                    m *= ((1.0/rr) * (t + 1))
+            if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
+                for t in range(rr):
+                    m = mult.narrow(i + 2, area[i] - 1 - t, 1)
+                    m *= ((1.0/rr) * (t + 1))
 
     conditioning = {}
     model_conds = conds["model_conds"]
@@ -222,8 +230,19 @@
         for o in range(batch_chunks):
             cond_index = cond_or_uncond[o]
-            out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-            out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
+            a = area[o]
+            if a is None:
+                out_conds[cond_index] += output[o] * mult[o]
+                out_counts[cond_index] += mult[o]
+            else:
+                out_c = out_conds[cond_index]
+                out_cts = out_counts[cond_index]
+                dims = len(a) // 2
+                for i in range(dims):
+                    out_c = out_c.narrow(i + 2, a[i + dims], a[i])
+                    out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
+                out_c += output[o] * mult[o]
+                out_cts += mult[o]
 
     for i in range(len(out_conds)):
         out_conds[i] /= out_counts[i]
@@ -338,7 +357,7 @@ def get_mask_aabb(masks):
     return bounding_boxes, is_empty
 
 
-def resolve_areas_and_cond_masks(conditions, h, w, device):
+def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
     # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
     # While we're doing this, we can also resolve the mask device and scaling for performance reasons
     for i in range(len(conditions)):
@@ -347,7 +366,14 @@
             area = c['area']
             if area[0] == "percentage":
                 modified = c.copy()
-                area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
+                a = area[1:]
+                a_len = len(a) // 2
+                area = ()
+                for d in range(len(dims)):
+                    area += (max(1, round(a[d] * dims[d])),)
+                for d in range(len(dims)):
+                    area += (round(a[d + a_len] * dims[d]),)
+
                 modified['area'] = area
                 c = modified
                 conditions[i] = c
@@ -356,12 +382,12 @@
             mask = c['mask']
             mask = mask.to(device=device)
             modified = c.copy()
-            if len(mask.shape) == 2:
+            if len(mask.shape) == len(dims):
                 mask = mask.unsqueeze(0)
-            if mask.shape[1] != h or mask.shape[2] != w:
-                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
+            if mask.shape[1:] != dims:
+                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
 
-            if modified.get("set_area_to_bounds", False):
+            if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
                 bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
                 boxes, is_empty = get_mask_aabb(bounds)
                 if is_empty[0]:
@@ -378,7 +404,11 @@
             modified['mask'] = mask
             conditions[i] = modified
 
-def create_cond_with_same_area_if_none(conds, c):
+def resolve_areas_and_cond_masks(conditions, h, w, device):
+    logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
+    return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
+
+def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
     if 'area' not in c:
         return
 
@@ -482,7 +512,10 @@ def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwar
         params = x.copy()
         params["device"] = device
         params["noise"] = noise
-        params["width"] = params.get("width", noise.shape[3] * 8)
+        default_width = None
+        if len(noise.shape) >= 4: #TODO: 8 multiple should be set by the model
+            default_width = noise.shape[3] * 8
+        params["width"] = params.get("width", default_width)
         params["height"] = params.get("height", noise.shape[2] * 8)
         params["prompt_type"] = params.get("prompt_type", prompt_type)
         for k in kwargs:
@@ -570,7 +603,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
 def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
     for k in conds:
         conds[k] = conds[k][:]
-        resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)
+        resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
 
     for k in conds:
         calculate_start_end_timesteps(model, conds[k])
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 8f8f39ea4..5665b129f 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -77,7 +77,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
-                 special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32
+                 special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
+                 return_projected_pooled=True): # clip-vit-base-patch32
         super().__init__()
         if special_tokens is None:
             special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
@@ -96,6 +97,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
 
         self.enable_attention_masks = enable_attention_masks
+        self.zero_out_masked = zero_out_masked
 
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
@@ -174,20 +176,23 @@
         attention_mask = None
         if self.enable_attention_masks:
             attention_mask = torch.zeros_like(tokens)
-            max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+            end_token = self.special_tokens.get("end", -1)
             for x in range(attention_mask.shape[0]):
                 for y in range(attention_mask.shape[1]):
                     attention_mask[x, y] = 1
-                    if tokens[x, y] == max_token:
+                    if tokens[x, y] == end_token:
                         break
 
         outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
         self.transformer.set_input_embeddings(backup_embeds)
 
         if self.layer == "last":
-            z = outputs[0]
+            z = outputs[0].float()
         else:
-            z = outputs[1]
+            z = outputs[1].float()
+
+        if self.zero_out_masked and attention_mask is not None:
+            z *= attention_mask.unsqueeze(-1).float()
 
         pooled_output = None
         if len(outputs) >= 3:
@@ -196,7 +201,7 @@
         elif outputs[2] is not None:
             pooled_output = outputs[2].float()
 
-        return z.float(), pooled_output
+        return z, pooled_output
 
     def encode(self, tokens):
         return self(tokens)
diff --git a/comfy/utils.py b/comfy/utils.py
index 92e36e343..e75d0412e 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -288,11 +288,11 @@ def unet_to_diffusers(unet_config):
 
     return diffusers_unet_map
 
-def repeat_to_batch_size(tensor, batch_size):
-    if tensor.shape[0] > batch_size:
-        return tensor[:batch_size]
-    elif tensor.shape[0] < batch_size:
-        return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
+def repeat_to_batch_size(tensor, batch_size, dim=0):
+    if tensor.shape[dim] > batch_size:
+        return tensor.narrow(dim, 0, batch_size)
+    elif tensor.shape[dim] < batch_size:
+        return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size)
     return tensor
diff --git a/comfy/web/scripts/app.js b/comfy/web/scripts/app.js
index 4dc011b9f..f96d197a8 100644
--- a/comfy/web/scripts/app.js
+++ b/comfy/web/scripts/app.js
@@ -1800,7 +1800,7 @@
	 * @param {*} graphData A serialized graph object
	 * @param { boolean } clean If the graph state, e.g. images, should be cleared
	 */
-	async loadGraphData(graphData, clean = true) {
+	async loadGraphData(graphData, clean = true, restore_view = true) {
		if (clean !== false) {
			this.clean();
		}
@@ -1836,7 +1836,7 @@
		try {
			this.graph.configure(graphData);
-			if (this.enableWorkflowViewRestore.value && graphData.extra?.ds) {
+			if (restore_view && this.enableWorkflowViewRestore.value && graphData.extra?.ds) {
				this.canvas.ds.offset = graphData.extra.ds.offset;
				this.canvas.ds.scale = graphData.extra.ds.scale;
			}
diff --git a/comfy/web/scripts/ui.js b/comfy/web/scripts/ui.js
index 36fed3238..72e43d357 100644
--- a/comfy/web/scripts/ui.js
+++ b/comfy/web/scripts/ui.js
@@ -228,7 +228,7 @@
					$el("button", {
						textContent: "Load",
						onclick: async () => {
-							await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
+							await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow, true, false);
							if (item.outputs) {
								app.nodeOutputs = item.outputs;
							}
diff --git a/comfy_extras/nodes/nodes_compositing.py b/comfy_extras/nodes/nodes_compositing.py
index 181b36ed6..48fe5e3dd 100644
--- a/comfy_extras/nodes/nodes_compositing.py
+++ b/comfy_extras/nodes/nodes_compositing.py
@@ -28,6 +28,14 @@ class PorterDuffMode(Enum):
 
 
 def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
+    # convert mask to alpha
+    src_alpha = 1 - src_alpha
+    dst_alpha = 1 - dst_alpha
+    # premultiply alpha
+    src_image = src_image * src_alpha
+    dst_image = dst_image * dst_alpha
+
+    # composite ops below assume alpha-premultiplied images
     if mode == PorterDuffMode.ADD:
         out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
         out_image = torch.clamp(src_image + dst_image, 0, 1)
@@ -35,7 +43,7 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
         out_alpha = torch.zeros_like(dst_alpha)
         out_image = torch.zeros_like(dst_image)
     elif mode == PorterDuffMode.DARKEN:
-        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha 
+        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
         out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
     elif mode == PorterDuffMode.DST:
         out_alpha = dst_alpha
@@ -84,8 +92,13 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
         out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
         out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
     else:
-        out_alpha = None
-        out_image = None
+        return None, None
+
+    # back to non-premultiplied alpha
+    out_image = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image))
+    out_image = torch.clamp(out_image, 0, 1)
+    # convert alpha to mask
+    out_alpha = 1 - out_alpha
 
     return out_image, out_alpha
diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py
index 7baaf5307..5e8c897e3 100644
--- a/comfy_extras/nodes/nodes_custom_sampler.py
+++ b/comfy_extras/nodes/nodes_custom_sampler.py
@@ -405,6 +405,7 @@ class SamplerCustom:
     def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
         latent = latent_image
         latent_image = latent["samples"]
+        latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
         if not add_noise:
             noise = Noise_EmptyNoise().generate_noise(latent)
         else:
@@ -563,6 +564,7 @@ class SamplerCustomAdvanced:
     def sample(self, noise, guider, sampler, sigmas, latent_image):
         latent = latent_image
         latent_image = latent["samples"]
+        latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
 
         noise_mask = None
         if "noise_mask" in latent:
diff --git a/tests/conftest.py b/tests/conftest.py
index b88faf344..08a40b9ca 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,4 @@
+import json
 import multiprocessing
 import pathlib
 import time
@@ -9,7 +10,7 @@ import pytest
 
 from comfy.cli_args_types import Configuration
 
-def run_server(server_arguments: dict):
+def run_server(server_arguments: Configuration):
     from comfy.cmd.main import main
     from comfy.cli_args import args
     import asyncio
@@ -18,18 +19,18 @@ def run_server(server_arguments: dict):
     asyncio.run(main())
 
 
-@pytest.fixture(scope="module", autouse=False)
-def comfy_background_server(use_temporary_output_directory, use_temporary_input_directory) -> Tuple[Configuration, multiprocessing.Process]:
+@pytest.fixture(scope="function", autouse=False)
+def comfy_background_server(tmp_path) -> Tuple[Configuration, multiprocessing.Process]:
     import torch
 
     # Start server
     configuration = Configuration()
-    configuration.listen = True
-    configuration.output_directory = str(use_temporary_output_directory)
-    configuration.input_directory = str(use_temporary_input_directory)
+    configuration.listen = "localhost"
+    configuration.output_directory = str(tmp_path)
+    configuration.input_directory = str(tmp_path)
 
-    p = multiprocessing.Process(target=run_server, args=(configuration,))
-    p.start()
+    server_process = multiprocessing.Process(target=run_server, args=(configuration,))
+    server_process.start()
     # wait for http url to be ready
     success = False
     for i in range(60):
@@ -43,8 +44,8 @@ def comfy_background_server(use_temporary_output_directory, use_temporary_input_
         time.sleep(1)
     if not success:
         raise Exception("Failed to start background server")
-    yield configuration, p
-    p.terminate()
+    yield configuration, server_process
+    server_process.terminate()
 
     torch.cuda.empty_cache()