From b0e25488dd98a78e730f476867f08c2d531f50e1 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Tue, 13 Aug 2024 20:51:07 -0700
Subject: [PATCH] Fix tokenizer cloning

---
 comfy/model_patcher.py                  | 4 ++--
 comfy/text_encoders/flux.py             | 9 ++++++++-
 comfy/text_encoders/hydit.py            | 3 +++
 comfy/text_encoders/sd3_clip.py         | 8 +++++++-
 comfy/text_encoders/spiece_tokenizer.py | 6 ++++++
 5 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index b4dfcaf3b..7cda39c87 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -416,10 +416,10 @@ class ModelPatcher(ModelManageable):
                     logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
 
         if lowvram_counter > 0:
-            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logging.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
             self._memory_measurements.model_lowvram = True
         else:
-            logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
+            logging.debug("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
             self._memory_measurements.model_lowvram = False
         self._memory_measurements.lowvram_patch_counter += patch_counter
         self._memory_measurements.model_loaded_weight_memory = mem_counter
diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py
index 8e3f32321..3495e64bb 100644
--- a/comfy/text_encoders/flux.py
+++ b/comfy/text_encoders/flux.py
@@ -1,3 +1,5 @@
+import copy
+
 import torch
 from transformers import T5TokenizerFast
 
@@ -21,7 +23,9 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
 
 
 class FluxTokenizer:
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = dict()
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
 
@@ -37,6 +41,9 @@ class FluxTokenizer:
     def state_dict(self):
         return {}
 
+    def clone(self):
+        return copy.copy(self)
+
 
 class FluxClipModel(torch.nn.Module):
     def __init__(self, dtype_t5=None, device="cpu", dtype=None):
diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py
index af88c6eb2..b4b7db385 100644
--- a/comfy/text_encoders/hydit.py
+++ b/comfy/text_encoders/hydit.py
@@ -59,6 +59,9 @@ class HyditTokenizer:
     def state_dict(self):
         return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
 
+    def clone(self):
+        return copy.copy(self)
+
 
 class HyditModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py
index 8de556d4e..2a8c39112 100644
--- a/comfy/text_encoders/sd3_clip.py
+++ b/comfy/text_encoders/sd3_clip.py
@@ -1,3 +1,4 @@
+import copy
 import logging
 
 import torch
@@ -24,7 +25,9 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
 
 
 class SD3Tokenizer:
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = dict()
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
         self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@@ -42,6 +45,9 @@ class SD3Tokenizer:
     def state_dict(self):
         return dict()
 
+    def clone(self):
+        return copy.copy(self)
+
 
 class SD3ClipModel(torch.nn.Module):
     def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py
index 2157011d9..77ab07f7d 100644
--- a/comfy/text_encoders/spiece_tokenizer.py
+++ b/comfy/text_encoders/spiece_tokenizer.py
@@ -1,3 +1,5 @@
+import copy
+
 import sentencepiece
 import torch
 
@@ -36,3 +38,7 @@ class SPieceTokenizer:
 
     def serialize_model(self):
         return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
+
+    def clone(self):
+        return copy.copy(self)
+
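
The patch applies the same two patterns to each tokenizer class: replacing the mutable `{}` default argument with `None`, and implementing `clone()` as a shallow `copy.copy` (note the hydit.py hunk assumes `copy` is already imported in that module, since the diff adds no import there). Below is a minimal sketch of both patterns; `ExampleTokenizer` is a hypothetical stand-in, not part of the patch:

```python
import copy


class ExampleTokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data=None):
        # A `None` default avoids the classic pitfall of one shared `{}`
        # being mutated across every instance of the class.
        if tokenizer_data is None:
            tokenizer_data = dict()
        self.tokenizer_data = tokenizer_data

    def clone(self):
        # Shallow copy: the clone is a distinct wrapper object but shares
        # the underlying tokenizer sub-objects, so cloning is cheap and
        # does not re-load any tokenizer state.
        return copy.copy(self)


a = ExampleTokenizer()
b = a.clone()
assert b is not a                             # distinct wrapper objects
assert b.tokenizer_data is a.tokenizer_data   # sub-objects are shared
```

The shallow copy is sufficient here because callers only replace attributes on the clone rather than mutating the shared tokenizer objects in place.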