mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
Fix tokenizer cloning
This commit is contained in:
parent
0549f35e85
commit
b0e25488dd
@ -416,10 +416,10 @@ class ModelPatcher(ModelManageable):
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
logging.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
self._memory_measurements.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
||||
logging.debug("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
||||
self._memory_measurements.model_lowvram = False
|
||||
self._memory_measurements.lowvram_patch_counter += patch_counter
|
||||
self._memory_measurements.model_loaded_weight_memory = mem_counter
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import copy
|
||||
|
||||
import torch
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
@ -21,7 +23,9 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
|
||||
class FluxTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = dict()
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
@ -37,6 +41,9 @@ class FluxTokenizer:
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def clone(self):
|
||||
return copy.copy(self)
|
||||
|
||||
|
||||
class FluxClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
|
||||
|
||||
@ -59,6 +59,9 @@ class HyditTokenizer:
|
||||
def state_dict(self):
|
||||
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
|
||||
|
||||
def clone(self):
|
||||
return copy.copy(self)
|
||||
|
||||
|
||||
class HyditModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import logging
|
||||
|
||||
import torch
|
||||
@ -24,7 +25,9 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = dict()
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||
@ -42,6 +45,9 @@ class SD3Tokenizer:
|
||||
def state_dict(self):
|
||||
return dict()
|
||||
|
||||
def clone(self):
|
||||
return copy.copy(self)
|
||||
|
||||
|
||||
class SD3ClipModel(torch.nn.Module):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import copy
|
||||
|
||||
import sentencepiece
|
||||
import torch
|
||||
|
||||
@ -36,3 +38,7 @@ class SPieceTokenizer:
|
||||
|
||||
def serialize_model(self):
|
||||
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
|
||||
|
||||
def clone(self):
|
||||
return copy.copy(self)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user