From 264d84db39cd59d3e84dbac06c61bccdcde4bfd2 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Mon, 18 Nov 2024 14:10:58 -0800
Subject: [PATCH] Fix Pylint warnings

---
 comfy/api_server/routes/internal/internal_routes.py | 2 +-
 comfy/ldm/aura/mmdit.py                             | 8 ++++----
 comfy/ldm/flux/model.py                             | 8 ++++----
 comfy/utils.py                                      | 8 ++++----
 4 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/comfy/api_server/routes/internal/internal_routes.py b/comfy/api_server/routes/internal/internal_routes.py
index 6c353156d..a9082cce0 100644
--- a/comfy/api_server/routes/internal/internal_routes.py
+++ b/comfy/api_server/routes/internal/internal_routes.py
@@ -45,7 +45,7 @@ class InternalRoutes:
             return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in logger.get_logs()]))
 
         @self.routes.get('/logs/raw')
-        async def get_logs(request):
+        async def get_logs_raw(request):
             self.terminal_service.update_size()
             return web.json_response({
                 "entries": list(logger.get_logs()),
diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py
index a1c498a08..50ab902d7 100644
--- a/comfy/ldm/aura/mmdit.py
+++ b/comfy/ldm/aura/mmdit.py
@@ -463,13 +463,13 @@ class MMDiT(nn.Module):
         if len(self.double_layers) > 0:
             for i, layer in enumerate(self.double_layers):
                 if ("double_block", i) in blocks_replace:
-                    def block_wrap(args):
+                    def block_wrap_1(args):
                         out = {}
                         out["txt"], out["img"] = layer(args["txt"], args["img"],
                                                        args["vec"])
                         return out
 
-                    out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
+                    out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap_1})
                     c = out["txt"]
                     x = out["img"]
                 else:
@@ -480,12 +480,12 @@ class MMDiT(nn.Module):
         cx = torch.cat([c, x], dim=1)
         for i, layer in enumerate(self.single_layers):
             if ("single_block", i) in blocks_replace:
-                def block_wrap(args):
+                def block_wrap_2(args):
                     out = {}
                     out["img"] = layer(args["img"],
                                        args["vec"])
                     return out
-                out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
+                out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap_2})
                 cx = out["img"]
             else:
                 cx = layer(cx, global_cond, **kwargs)
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 87a66acfa..286aa398d 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -120,12 +120,12 @@ class Flux(nn.Module):
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.double_blocks):
             if ("double_block", i) in blocks_replace:
-                def block_wrap(args):
+                def block_wrap_1(args):
                     out = {}
                     out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
                     return out
 
-                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap_1})
                 txt = out["txt"]
                 img = out["img"]
             else:
@@ -142,12 +142,12 @@ class Flux(nn.Module):
 
         for i, block in enumerate(self.single_blocks):
             if ("single_block", i) in blocks_replace:
-                def block_wrap(args):
+                def block_wrap_2(args):
                     out = {}
                     out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
                     return out
 
-                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
+                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap_2})
                 img = out["img"]
             else:
                 img = block(img, vec=vec, pe=pe)
diff --git a/comfy/utils.py b/comfy/utils.py
index 25a3b52f7..19282080b 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -1057,15 +1057,15 @@ def reshape_mask(input_mask, output_shape):
 
     if dims == 1:
         scale_mode = "linear"
-
-    if dims == 2:
+    elif dims == 2:
         input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
         scale_mode = "bilinear"
-
-    if dims == 3:
+    elif dims == 3:
         if len(input_mask.shape) < 5:
             input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
         scale_mode = "trilinear"
+    else:
+        raise ValueError(f"invalid dims={dims}")
 
     mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
     if mask.shape[1] < output_shape[1]: