mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
tensor
This commit is contained in:
parent
3b0368aa34
commit
846a019fc2
@ -183,7 +183,7 @@ class Chroma(nn.Module):
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
|
||||
if i not in self.skip_mmdit:
|
||||
double_mod = (
|
||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||
@ -229,7 +229,7 @@ class Chroma(nn.Module):
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
|
||||
if i not in self.skip_dit:
|
||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
|
||||
@ -186,7 +186,7 @@ class Flux(nn.Module):
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -233,7 +233,7 @@ class Flux(nn.Module):
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
@ -392,7 +392,7 @@ class HunyuanVideo(nn.Module):
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -417,7 +417,7 @@ class HunyuanVideo(nn.Module):
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=img.device)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
@ -375,7 +375,7 @@ class Kandinsky5(nn.Module):
|
||||
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.visual_transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=visual_embed.device)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
|
||||
|
||||
@ -921,7 +921,7 @@ class SpatialTransformer(nn.Module):
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=x.device)
|
||||
x = block(x, context=context[i], transformer_options=transformer_options)
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
@ -1067,7 +1067,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
for it_, (block, mix_block) in enumerate(
|
||||
zip(self.transformer_blocks, self.time_stack)
|
||||
):
|
||||
transformer_options["block_index"] = it_
|
||||
transformer_options["block_index"] = torch.tensor(it_, dtype=torch.uint8, device=x.device)
|
||||
x = block(
|
||||
x,
|
||||
context=spatial_context,
|
||||
|
||||
@ -442,7 +442,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
transformer_options["block_index"] = torch.tensor(i, dtype=torch.uint8, device=hidden_states.device)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user