Fix tokenizer cloning

doctorpangloss 2024-08-13 20:51:07 -07:00
parent 0549f35e85
commit b0e25488dd
5 changed files with 26 additions and 4 deletions

@@ -416,10 +416,10 @@ class ModelPatcher(ModelManageable):
                     logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
 
         if lowvram_counter > 0:
-            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logging.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
             self._memory_measurements.model_lowvram = True
         else:
-            logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
+            logging.debug("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
             self._memory_measurements.model_lowvram = False
         self._memory_measurements.lowvram_patch_counter += patch_counter
         self._memory_measurements.model_loaded_weight_memory = mem_counter
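
The only change in this hunk demotes the per-load memory summary from INFO to DEBUG, so it no longer appears at the default log level. If the "loaded partially"/"loaded completely" lines are still wanted, raising the logger level at startup brings them back; a minimal sketch, assuming nothing else has already configured the root logger:

import logging

# Assumption: these messages go through the stdlib root logger via logging.debug().
# Switching the root level to DEBUG surfaces them again without touching library code.
logging.basicConfig(level=logging.DEBUG, format="%(levelname)s %(name)s: %(message)s")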

@@ -1,3 +1,5 @@
+import copy
 import torch
 from transformers import T5TokenizerFast
@@ -21,7 +23,9 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
 class FluxTokenizer:
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = dict()
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@@ -37,6 +41,9 @@ class FluxTokenizer:
     def state_dict(self):
         return {}
 
+    def clone(self):
+        return copy.copy(self)
+
 class FluxClipModel(torch.nn.Module):
     def __init__(self, dtype_t5=None, device="cpu", dtype=None):
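
Besides the new clone(), the FluxTokenizer constructor stops using a mutable default argument: with tokenizer_data={} in the signature, every call that omits the argument shares one dict object, so state written by one tokenizer leaks into the next; the None sentinel gives each instance a fresh dict. A standalone sketch of the pitfall (function names are illustrative, not from this repository):

def build_bad(tokenizer_data={}):
    # The default dict is created once at definition time and shared by all calls.
    tokenizer_data["calls"] = tokenizer_data.get("calls", 0) + 1
    return tokenizer_data

def build_good(tokenizer_data=None):
    # A fresh dict per call, matching the pattern adopted in this commit.
    if tokenizer_data is None:
        tokenizer_data = dict()
    tokenizer_data["calls"] = tokenizer_data.get("calls", 0) + 1
    return tokenizer_data

print(build_bad(), build_bad())    # {'calls': 1} {'calls': 2}  <- state leaks across calls
print(build_good(), build_good())  # {'calls': 1} {'calls': 1}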

@@ -59,6 +59,9 @@ class HyditTokenizer:
     def state_dict(self):
         return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
 
+    def clone(self):
+        return copy.copy(self)
+
 class HyditModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
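
HyditTokenizer gets the same clone() as the other wrappers. copy.copy is a shallow copy: it creates a new wrapper object whose attributes still point at the same underlying sub-tokenizers, so cloning is cheap and does not duplicate vocabularies. A small illustration of that behaviour with a stand-in class (not the repository's):

import copy

class Wrapper:
    def __init__(self):
        self.inner = {"vocab": "shared"}  # stands in for an expensive sub-tokenizer

    def clone(self):
        return copy.copy(self)  # shallow: copies the wrapper, not .inner

a = Wrapper()
b = a.clone()
print(b is a)              # False -- distinct wrapper objects
print(b.inner is a.inner)  # True  -- the sub-tokenizer is shared, not duplicated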

@@ -1,3 +1,4 @@
+import copy
 import logging
 import torch
@@ -24,7 +25,9 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
 class SD3Tokenizer:
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = dict()
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
         self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@@ -42,6 +45,9 @@ class SD3Tokenizer:
     def state_dict(self):
         return dict()
 
+    def clone(self):
+        return copy.copy(self)
+
 class SD3ClipModel(torch.nn.Module):
     def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
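
SD3Tokenizer receives the identical pair of fixes. The practical value of a uniform clone() is presumably that callers can duplicate a tokenizer and rebind attributes on the copy without reloading anything or disturbing the shared instance; a hedged usage sketch with an invented attribute, not code from the repository:

import copy

class ExampleTokenizer:  # stand-in for SD3Tokenizer / FluxTokenizer
    def __init__(self, tokenizer_data=None):
        if tokenizer_data is None:
            tokenizer_data = dict()
        self.tokenizer_data = tokenizer_data

    def clone(self):
        return copy.copy(self)

base = ExampleTokenizer()
variant = base.clone()
variant.tokenizer_data = {"min_length": 256}  # rebind on the clone only
print(base.tokenizer_data)     # {} -- the original is untouched
print(variant.tokenizer_data)  # {'min_length': 256}

Because the copy is shallow, rebinding an attribute on the clone is safe, but mutating a sub-object shared with the original would still be visible on both.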

@@ -1,3 +1,5 @@
+import copy
 import sentencepiece
 import torch
@@ -36,3 +38,7 @@ class SPieceTokenizer:
     def serialize_model(self):
         return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
 
+    def clone(self):
+        return copy.copy(self)
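
For SPieceTokenizer the only addition is clone(); the surrounding context shows serialize_model(), which packs the sentencepiece model proto into a torch.ByteTensor so it can travel inside a state dict. A sketch of the reverse direction implied by that representation, assuming self.tokenizer is a sentencepiece.SentencePieceProcessor (the helper name is hypothetical, not part of this commit):

import sentencepiece
import torch

def deserialize_model(byte_tensor: torch.Tensor) -> sentencepiece.SentencePieceProcessor:
    # Reverse of serialize_model: uint8 tensor -> raw bytes -> loaded processor.
    proto = bytes(byte_tensor.tolist())
    sp = sentencepiece.SentencePieceProcessor()
    sp.LoadFromSerializedProto(proto)
    return sp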