mirror of https://github.com/comfyanonymous/ComfyUI.git

commit 5884783f90
parent 846a019fc2

    device
@@ -183,7 +183,7 @@ class Chroma(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if i not in self.skip_mmdit:
                 double_mod = (
                     self.get_modulations(mod_vectors, "double_img", idx=i),

@@ -229,7 +229,7 @@ class Chroma(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if i not in self.skip_dit:
                 single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                 if ("single_block", i) in blocks_replace:
@@ -186,7 +186,7 @@ class Flux(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

@@ -233,7 +233,7 @@ class Flux(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -392,7 +392,7 @@ class HunyuanVideo(nn.Module):
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

@@ -417,7 +417,7 @@ class HunyuanVideo(nn.Module):
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -375,7 +375,7 @@ class Kandinsky5(nn.Module):
         transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.visual_transformer_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=visual_embed.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
@@ -921,7 +921,7 @@ class SpatialTransformer(nn.Module):
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)

@@ -1067,7 +1067,7 @@ class SpatialVideoTransformer(SpatialTransformer):
         for it_, (block, mix_block) in enumerate(
             zip(self.transformer_blocks, self.time_stack)
         ):
-            transformer_options["block_index"] = torch.tensor(it_, dtype=torch.uint8, device=x.device)
+            transformer_options["block_index"] = torch.tensor(it_, dtype=torch.uint8, device="cpu")
             x = block(
                 x,
                 context=spatial_context,
@@ -442,7 +442,7 @@ class QwenImageTransformer2DModel(nn.Module):
         transformer_options["total_blocks"] = len(self.transformer_blocks)
         transformer_options["block_type"] = "double"
         for i, block in enumerate(self.transformer_blocks):
-            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=hidden_states.device)
+            transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device="cpu")
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}