diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index c7843d402..c7ef93ce1 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -54,7 +54,7 @@ jobs: cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z cd ComfyUI_windows_portable_nightly_pytorch diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 260a51bb2..bef1868b9 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -46,6 +46,10 @@ fp_group = parser.add_mutually_exclusive_group() fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") +fpvae_group = parser.add_mutually_exclusive_group() +fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") +fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.") + parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") class LatentPreviewMethod(enum.Enum): diff --git a/comfy/model_management.py b/comfy/model_management.py index a918a81f6..09dcaa295 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -366,6 +366,14 @@ def vae_offload_device(): else: return torch.device("cpu") +def vae_dtype(): + if args.fp16_vae: + return torch.float16 + elif args.bf16_vae: + return torch.bfloat16 + else: + return torch.float32 + def get_autocast_device(dev): if hasattr(dev, 'type'): return dev.type diff --git a/comfy/samplers.py b/comfy/samplers.py index b5f79c058..81d1facd8 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -374,7 +374,7 @@ def resolve_cond_masks(conditions, h, w, device): modified = c[1].copy() if len(mask.shape) == 2: mask = mask.unsqueeze(0) - if mask.shape[2] != h or mask.shape[3] != w: + if mask.shape[1] != h or mask.shape[2] != w: mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) if modified.get("set_area_to_bounds", False): diff --git a/comfy/sd.py b/comfy/sd.py index 7e64536c1..76eaa5b59 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -505,6 +505,8 @@ class VAE: device = model_management.vae_device() self.device = device self.offload_device = model_management.vae_offload_device() + self.vae_dtype = model_management.vae_dtype() + self.first_stage_model.to(self.vae_dtype) def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) @@ -512,7 +514,7 @@ class VAE: steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = utils.ProgressBar(steps) - decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0) + decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() output = torch.clamp(( (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + @@ -526,7 +528,7 @@ class VAE: steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() + encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.vae_dtype).to(self.device) - 1.).sample().float() samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) @@ -543,8 +545,8 @@ class VAE: pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") for x in range(0, samples_in.shape[0], batch_number): - samples = samples_in[x:x+batch_number].to(self.device) - pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu() + samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) + pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu().float() except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) @@ -570,8 +572,8 @@ class VAE: batch_number = max(1, batch_number) samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device) - samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() + pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) + samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float() except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py index 9ee23c752..dcf8859fa 100644 --- a/comfy_extras/nodes_clip_sdxl.py +++ b/comfy_extras/nodes_clip_sdxl.py @@ -41,6 +41,12 @@ class CLIPTextEncodeSDXL: def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l): tokens = clip.tokenize(text_g) tokens["l"] = clip.tokenize(text_l)["l"] + if len(tokens["l"]) != len(tokens["g"]): + empty = clip.tokenize("") + while len(tokens["l"]) < len(tokens["g"]): + tokens["l"] += empty["l"] + while len(tokens["l"]) > len(tokens["g"]): + tokens["g"] += empty["g"] cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], ) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index e72f81f28..36078fffc 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -13,6 +13,7 @@ a111: models/LyCORIS upscale_models: | models/ESRGAN + models/RealESRGAN models/SwinIR embeddings: embeddings hypernetworks: models/hypernetworks