Mirror of https://github.com/comfyanonymous/ComfyUI.git
Lower vram usage for flux 2 text encoder. (#10887)
commit d196a905bb (parent 18b79acba9)
@@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                  special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
                  return_projected_pooled=True, return_attention_masks=False, model_options={}):  # clip-vit-base-patch32
         super().__init__()
-        assert layer in self.LAYERS

         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
@@ -164,7 +163,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def set_clip_options(self, options):
         layer_idx = options.get("layer", self.layer_idx)
         self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
-        if self.layer == "all":
+        if isinstance(self.layer, list) or self.layer == "all":
             pass
         elif layer_idx is None or abs(layer_idx) > self.num_layers:
             self.layer = "last"
@@ -266,7 +265,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if self.enable_attention_masks:
             attention_mask_model = attention_mask

-        if self.layer == "all":
+        if isinstance(self.layer, list):
+            intermediate_output = self.layer
+        elif self.layer == "all":
             intermediate_output = "all"
         else:
             intermediate_output = self.layer_idx
@@ -138,7 +138,7 @@ class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
         return tokens

 class Mistral3_24BModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+    def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
         textmodel_json_config = {}
         num_layers = model_options.get("num_layers", None)
         if num_layers is not None:
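The default `layer` for the Flux 2 Mistral encoder switches from `"all"` to the list `[10, 20, 30]`, so only three hidden states are captured instead of one per transformer layer, which is where the VRAM saving comes from. A list value also cannot satisfy a membership check against the string options in `LAYERS`, which is why `assert layer in self.LAYERS` is dropped in the first hunk. A minimal sketch of that incompatibility, assuming `LAYERS` holds string options such as `"last"`, `"hidden"` and `"all"` (its exact contents are not shown in this diff):

# Minimal sketch: a list default cannot pass the old membership assert.
LAYERS = ["last", "pooled", "hidden", "all"]  # assumed string options; not taken from this diff

layer = [10, 20, 30]      # new Flux 2 default: capture only these hidden states
print(layer in LAYERS)    # False -> the old `assert layer in self.LAYERS` would raise
print("all" in LAYERS)    # True  -> the previous default passed the assert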
@@ -154,7 +154,7 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
     def encode_token_weights(self, token_weight_pairs):
         out, pooled, extra = super().encode_token_weights(token_weight_pairs)

-        out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1)
+        out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
         out = out.movedim(1, 2)
         out = out.reshape(out.shape[0], out.shape[1], -1)
         return out, pooled, extra
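Because the backend now returns only the requested hidden states, the three captured layers sit at indices 0, 1 and 2 instead of 10, 20 and 30, and the stack/movedim/reshape sequence concatenates them along the feature dimension. A standalone shape walkthrough with dummy tensors (batch, token and hidden sizes are made up for illustration):

import torch

batch, tokens, dim = 2, 16, 64          # made-up sizes for illustration
# Captured hidden states as returned by the backend: [batch, num_captured, tokens, dim]
out = torch.randn(batch, 3, tokens, dim)

out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)  # [batch, 3, tokens, dim]
out = out.movedim(1, 2)                                      # [batch, tokens, 3, dim]
out = out.reshape(out.shape[0], out.shape[1], -1)            # [batch, tokens, 3 * dim]
print(out.shape)  # torch.Size([2, 16, 192])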
@@ -434,8 +434,12 @@ class Llama2_(nn.Module):

         intermediate = None
         all_intermediate = None
+        only_layers = None
         if intermediate_output is not None:
-            if intermediate_output == "all":
+            if isinstance(intermediate_output, list):
+                all_intermediate = []
+                only_layers = set(intermediate_output)
+            elif intermediate_output == "all":
                 all_intermediate = []
                 intermediate_output = None
             elif intermediate_output < 0:
@@ -443,7 +447,8 @@ class Llama2_(nn.Module):

         for i, layer in enumerate(self.layers):
             if all_intermediate is not None:
-                all_intermediate.append(x.unsqueeze(1).clone())
+                if only_layers is None or (i in only_layers):
+                    all_intermediate.append(x.unsqueeze(1).clone())
             x = layer(
                 x=x,
                 attention_mask=mask,
@@ -457,7 +462,8 @@ class Llama2_(nn.Module):
         x = self.norm(x)

         if all_intermediate is not None:
-            all_intermediate.append(x.unsqueeze(1).clone())
+            if only_layers is None or ((i + 1) in only_layers):
+                all_intermediate.append(x.unsqueeze(1).clone())

         if all_intermediate is not None:
             intermediate = torch.cat(all_intermediate, dim=1)
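The `only_layers` set is what actually lowers memory use: instead of cloning every layer's hidden state, the loop keeps only the ones whose index was requested, with `i + 1` covering the post-norm output after the last layer. A stripped-down sketch of the same bookkeeping, with dummy tensors standing in for the real transformer layers:

import torch

num_layers = 40                       # made-up layer count for illustration
intermediate_output = [10, 20, 30]    # the new Flux 2 request

all_intermediate = []
only_layers = set(intermediate_output)

x = torch.randn(2, 16, 64)            # dummy hidden state: [batch, tokens, dim]
for i in range(num_layers):
    if only_layers is None or (i in only_layers):
        all_intermediate.append(x.unsqueeze(1).clone())
    x = x + 1                         # stand-in for the real transformer layer

# The final (post-norm) state would also be captured if (i + 1) was requested.
if only_layers is None or ((i + 1) in only_layers):
    all_intermediate.append(x.unsqueeze(1).clone())

intermediate = torch.cat(all_intermediate, dim=1)
print(intermediate.shape)  # torch.Size([2, 3, 16, 64]) -- 3 states kept instead of 41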