diff --git a/comfy/sd.py b/comfy/sd.py
index 3747f53b8..d898d0197 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -544,6 +544,19 @@ class VAE:
             / 3.0) / 2.0, min=0.0, max=1.0)
         return output
 
+    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
+        steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
+        steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
+        steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
+        pbar = utils.ProgressBar(steps)
+
+        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor
+        samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples /= 3.0
+        return samples
+
     def decode(self, samples_in):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
@@ -574,28 +587,29 @@ class VAE:
     def encode(self, pixel_samples):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
-        samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor
+        pixel_samples = pixel_samples.movedim(-1,1)
+        try:
+            free_memory = model_management.get_free_memory(self.device)
+            batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
+            batch_number = max(1, batch_number)
+            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
+            for x in range(0, pixel_samples.shape[0], batch_number):
+                pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor
+
+        except model_management.OOM_EXCEPTION as e:
+            print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+            samples = self.encode_tiled_(pixel_samples)
+
         self.first_stage_model = self.first_stage_model.cpu()
-        samples = samples.cpu()
         return samples
 
     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
-        pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
-
-        steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
-        steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
-        steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
-        pbar = utils.ProgressBar(steps)
-
-        samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples /= 3.0
+        pixel_samples = pixel_samples.movedim(-1,1)
+        samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
         self.first_stage_model = self.first_stage_model.cpu()
-        samples = samples.cpu()
         return samples
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
diff --git a/nodes.py b/nodes.py
index 14c4f6dae..6ae911fdd 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1315,6 +1315,26 @@ class ImageScale:
         s = s.movedim(1,-1)
         return (s,)
 
+class ImageScaleBy:
+    upscale_methods = ["nearest-exact", "bilinear", "area"]
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
+                              "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}}
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "upscale"
+
+    CATEGORY = "image/upscaling"
+
+    def upscale(self, image, upscale_method, scale_by):
+        samples = image.movedim(-1,1)
+        width = round(samples.shape[3] * scale_by)
+        height = round(samples.shape[2] * scale_by)
+        s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
+        s = s.movedim(1,-1)
+        return (s,)
+
 class ImageInvert:
 
     @classmethod
@@ -1413,6 +1433,7 @@ NODE_CLASS_MAPPINGS = {
     "LoadImage": LoadImage,
     "LoadImageMask": LoadImageMask,
     "ImageScale": ImageScale,
+    "ImageScaleBy": ImageScaleBy,
     "ImageInvert": ImageInvert,
     "ImagePadForOutpaint": ImagePadForOutpaint,
     "ConditioningAverage ": ConditioningAverage ,
@@ -1495,6 +1516,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "LoadImage": "Load Image",
     "LoadImageMask": "Load Image (as Mask)",
     "ImageScale": "Upscale Image",
+    "ImageScaleBy": "Upscale Image By",
     "ImageUpscaleWithModel": "Upscale Image (using Model)",
     "ImageInvert": "Invert Image",
     "ImagePadForOutpaint": "Pad Image for Outpainting",