Fix Pylint warnings

This commit is contained in:
doctorpangloss 2024-11-18 14:10:58 -08:00
parent fb7a3f9386
commit 264d84db39
4 changed files with 13 additions and 13 deletions

View File

@ -45,7 +45,7 @@ class InternalRoutes:
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in logger.get_logs()]))
@self.routes.get('/logs/raw')
async def get_logs(request):
async def get_logs_raw(request):
self.terminal_service.update_size()
return web.json_response({
"entries": list(logger.get_logs()),

View File

@ -463,13 +463,13 @@ class MMDiT(nn.Module):
if len(self.double_layers) > 0:
for i, layer in enumerate(self.double_layers):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
def block_wrap_1(args):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap_1})
c = out["txt"]
x = out["img"]
else:
@ -480,12 +480,12 @@ class MMDiT(nn.Module):
cx = torch.cat([c, x], dim=1)
for i, layer in enumerate(self.single_layers):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
def block_wrap_2(args):
out = {}
out["img"] = layer(args["img"], args["vec"])
return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap_2})
cx = out["img"]
else:
cx = layer(cx, global_cond, **kwargs)

View File

@ -120,12 +120,12 @@ class Flux(nn.Module):
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
def block_wrap_1(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap_1})
txt = out["txt"]
img = out["img"]
else:
@ -142,12 +142,12 @@ class Flux(nn.Module):
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
def block_wrap_2(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap_2})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe)

View File

@ -1057,15 +1057,15 @@ def reshape_mask(input_mask, output_shape):
if dims == 1:
scale_mode = "linear"
if dims == 2:
elif dims == 2:
input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
scale_mode = "bilinear"
if dims == 3:
elif dims == 3:
if len(input_mask.shape) < 5:
input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
scale_mode = "trilinear"
else:
raise ValueError(f"invalid dims={dims}")
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
if mask.shape[1] < output_shape[1]: