diff --git a/README.md b/README.md index 6cc601fa0..5803995b8 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,10 @@ When using Windows, open the **Windows Powershell** app. Then observe you are at pip install xformers==0.0.26.post1 pip install --no-build-isolation git+https://github.com/hiddenswitch/ComfyUI.git ``` + For improved performance when using the language models on Windows, CUDA 12.1 and PyTorch 2.3.0, add: + ```shell + pip install flash-attn @ https://github.com/AppMana/appmana-comfyui-nodes-extramodels/releases/download/v0.0.0-flash_attn/flash_attn-2.5.9.post1-cp311-cp311-win_amd64.whl + ``` Flash Attention as implemented in PyTorch is not functional on any version of Windows. ComfyUI will always run with "memory efficient attention" in practice on this platform. This is distinct from the `flash-attn` package.
**Advanced**: If you are running in Google Collab or another environment which has already installed `torch` for you, disable build isolation, and the package will recognize your currently installed torch. ```shell @@ -291,6 +295,8 @@ ComfyUI LTS supports text and multi-modal LLM models from the `transformers` eco In this example, LLAVA-NEXT (LLAVA 1.6) is prompted to describe an image. +You can try the [LLAVA-NEXT](tests/inference/workflows/llava-0.json), [Phi-3](tests/inference/workflows/phi-3-0.json), and two [translation](tests/inference/workflows/translation-0.json) [workflows](tests/inference/workflows/translation-1.json). + # Video Workflows ComfyUI LTS supports video workflows with AnimateDiff Evolved. diff --git a/comfy/language/language_types.py b/comfy/language/language_types.py index a6d26f64f..c77dd6f1d 100644 --- a/comfy/language/language_types.py +++ b/comfy/language/language_types.py @@ -20,5 +20,5 @@ class ProcessorResult(TypedDict): pixel_values: NotRequired[torch.Tensor] images: NotRequired[torch.Tensor] - inputs: BatchEncoding + inputs: NotRequired[BatchEncoding] image_sizes: NotRequired[torch.Tensor] diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py index 9fa4ca1bd..2b408a395 100644 --- a/comfy/language/transformers_model_management.py +++ b/comfy/language/transformers_model_management.py @@ -135,14 +135,14 @@ class TransformersManagedModel(ModelManageable): if processor is not None and hasattr(processor, "image_processor") and hasattr(processor.image_processor, "do_rescale"): processor.image_processor.do_rescale = False - def tokenize(self, prompt: str, images: List[torch.Tensor] | torch.Tensor, chat_template: str) -> ProcessorResult: + def tokenize(self, prompt: str, images: List[torch.Tensor] | torch.Tensor, chat_template: str | None = None) -> ProcessorResult: tokenizer = self.tokenizer assert tokenizer is not None assert hasattr(tokenizer, "decode") # try to retrieve a matching chat template chat_template = chat_template or tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None - if chat_template is None: + if chat_template is None and self.config_dict is not None and "_name_or_path" in self.config_dict: candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in self.config_dict["_name_or_path"] or name in self.model.name_or_path] if len(candidate_chat_templates) > 0: filename, chat_template = candidate_chat_templates[0] diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 5a2719348..4f3bb5e07 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -409,7 +409,9 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = { 'JingyeChen22/textdiffuser2_layout_planner', 'JingyeChen22/textdiffuser2-full-ft', 'microsoft/Phi-3-mini-4k-instruct', - 'llava-hf/llava-v1.6-mistral-7b-hf' + 'llava-hf/llava-v1.6-mistral-7b-hf', + 'facebook/nllb-200-distilled-1.3B', + 'THUDM/chatglm3-6b', } KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([ diff --git a/comfy/model_management.py b/comfy/model_management.py index 0f3dac895..261f1c62d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -689,8 +689,8 @@ def unet_initial_load_device(parameters, dtype): return cpu_dev -def maximum_vram_for_weights(device=None): - return (get_total_memory(device) * 0.88 - minimum_inference_memory()) +def maximum_vram_for_weights(device=None) -> int: + return get_total_memory(device) * 0.88 - minimum_inference_memory() def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)): diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py index c81b14570..39bc8dab5 100644 --- a/comfy_extras/nodes/nodes_language.py +++ b/comfy_extras/nodes/nodes_language.py @@ -10,8 +10,14 @@ from typing import Any, Dict, Optional, List, Callable, Union import torch from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \ - PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \ - LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel, AutoModelForCausalLM + PreTrainedTokenizerBase, PretrainedConfig, AutoProcessor, BatchFeature, AutoModel, AutoModelForCausalLM, \ + AutoModelForSeq2SeqLM +from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, \ + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, AutoModelForVision2Seq +from transformers.models.m2m_100.tokenization_m2m_100 import \ + FAIRSEQ_LANGUAGE_CODES as tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES +from transformers.models.nllb.tokenization_nllb import \ + FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES from typing_extensions import TypedDict from comfy import model_management @@ -27,7 +33,7 @@ _AUTO_CHAT_TEMPLATE = "default" # add llava support try: - from llava import model + from llava import model as _llava_model_side_effects logging.debug("Additional LLaVA models are now supported") except ImportError as exc: @@ -241,39 +247,70 @@ class TransformersLoader(CustomNode): with comfy_tqdm(): from_pretrained_kwargs = { "pretrained_model_name_or_path": ckpt_name, - "torch_dtype": unet_dtype(), - "device_map": get_torch_device_name(unet_offload_device()), - "low_cpu_mem_usage": True, "trust_remote_code": True, **hub_kwargs } - # try: - # import flash_attn - # from_pretrained_kwargs["attn_implementation"] = "flash_attention_2" - # except ImportError: - # logging.debug("install flash_attn for improved performance using language nodes") - - config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs) - - if config_dict["model_type"] == "llava_next": - model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs) - else: - try: - model = AutoModel.from_pretrained(**from_pretrained_kwargs) - except Exception: - model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs) + # if flash attention exists, use it + # compute bitsandbytes configuration try: + import bitsandbytes + except ImportError: + pass + + config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs) + model_type = config_dict["model_type"] + # language models prefer to use bfloat16 over float16 + kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)), + "low_cpu_mem_usage": True, + "device_map": str(unet_offload_device()), }, {}) + + # if we have flash-attn installed, try to use it + try: + import flash_attn + attn_override_kwargs = { + "attn_implementation": "flash_attention_2", + **kwargs_to_try[0] + } + kwargs_to_try = (attn_override_kwargs, *kwargs_to_try) + logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried") + except ImportError: + pass + for i, props in enumerate(kwargs_to_try): try: - processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs) - except: - processor = LlavaNextProcessor.from_pretrained(**from_pretrained_kwargs) - except: - processor = None - if not isinstance(processor, ProcessorMixin): - processor = None - tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs) + if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES: + model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props) + elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: + model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props) + elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props) + else: + model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props) + if model is not None: + break + except Exception as exc_info: + if i == len(kwargs_to_try) - 1: + raise exc_info + else: + logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info) + + for i, props in enumerate(kwargs_to_try): + try: + try: + processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props) + except: + processor = None + if isinstance(processor, PreTrainedTokenizerBase): + tokenizer = processor + processor = None + else: + tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props) + if tokenizer is not None or processor is not None: + break + except Exception as exc_info: + if i == len(kwargs_to_try) - 1: + raise exc_info if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"): model.enable_xformers_memory_efficient_attention() @@ -289,6 +326,108 @@ class TransformersLoader(CustomNode): return model_managed, +class TransformersTokenize(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "model": ("MODEL",), + "prompt": ("STRING", {"default": "", "multiline": True}), + }, + } + + CATEGORY = "language" + RETURN_TYPES = (TOKENS_TYPE_NAME,) + FUNCTION = "execute" + + def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult: + return model.tokenize(prompt, [], None), + + +class TransformersM2M100LanguageCodes(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "lang_id": (tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES["m2m100"], {"default": "en"}), + }, + } + + CATEGORY = "language" + RETURN_TYPES = ("STRING",) + FUNCTION = "execute" + + def execute(self, lang_id: str) -> ValidatedNodeResult: + return lang_id, + + +class TransformersFlores200LanguageCodes(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "lang_id": (tokenization_nllb_FAIRSEQ_LANGUAGE_CODES, {"default": "eng_Latn"}), + }, + } + + CATEGORY = "language" + RETURN_TYPES = ("STRING",) + FUNCTION = "execute" + + def execute(self, lang_id: str) -> ValidatedNodeResult: + return lang_id, + + +class TransformersTranslationTokenize(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "model": ("MODEL",), + "prompt": ("STRING", {"default": "", "multiline": True}), + "src_lang": ("STRING", {}), + "tgt_lang": ("STRING", {}), + }, + } + + CATEGORY = "language" + RETURN_TYPES = (TOKENS_TYPE_NAME,) + FUNCTION = "execute" + + def execute(self, model: TransformersManagedModel, prompt: str, src_lang: str, tgt_lang: str) -> ValidatedNodeResult: + tokenizer = model.tokenizer + + if hasattr(tokenizer, "src_lang"): + prev_src_lang = tokenizer.src_lang + else: + prev_src_lang = None + if hasattr(tokenizer, "tgt_lang"): + prev_tgt_lang = tokenizer.tgt_lang + else: + prev_tgt_lang = None + + try: + if hasattr(tokenizer, "_build_translation_inputs"): + encoded = tokenizer._build_translation_inputs( + prompt, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang + ) + else: + tokenizer.src_lang = src_lang + tokenizer.tgt_lang = tgt_lang + + encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True) + encoded["input_ids"] = encoded["input_ids"].to(device=model.load_device) + encoded["attention_mask"] = encoded["attention_mask"].to(device=model.load_device) + encoded["src_lang"] = src_lang + encoded["tgt_lang"] = tgt_lang + return encoded, + finally: + if prev_src_lang is not None: + tokenizer.src_lang = prev_src_lang + if prev_tgt_lang is not None: + tokenizer.tgt_lang = prev_tgt_lang + + class OneShotInstructTokenize(CustomNode): @classmethod def INPUT_TYPES(cls) -> InputTypes: @@ -352,6 +491,7 @@ class TransformersGenerate(CustomNode): **kwargs ): tokens = copy.copy(tokens) + tokens_original = copy.copy(tokens) sampler = sampler or {} generate_kwargs = copy.copy(sampler) load_models_gpu([model]) @@ -363,6 +503,16 @@ class TransformersGenerate(CustomNode): prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature]))) gen_sig_keys = generate_signature.keys() + if "tgt_lang" in tokens: + to_delete.add("tgt_lang") + to_delete.add("src_lang") + to_delete.discard("input_ids") + if "forced_bos_token_id" in tokens: + to_delete.discard("forced_bos_token_id") + elif hasattr(tokenizer, "convert_tokens_to_ids"): + generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"]) + else: + logging.warning(f"tokenizer {tokenizer} unexpected for translation task") if "input_ids" in tokens and "inputs" in tokens: if "input_ids" in gen_sig_keys: to_delete.add("inputs") @@ -370,14 +520,19 @@ class TransformersGenerate(CustomNode): to_delete.add("input_ids") for unused_kwarg in to_delete: tokens.pop(unused_kwarg) - logging.info(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing") + logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing") # images should be moved to model for key in ("images", "pixel_values"): if key in tokens: tokens[key] = tokens[key].to(device=model.current_device, dtype=model.model_dtype()) + + # sets up inputs inputs = tokens - progress_logits_processor = _ProgressLogitsProcessor(model) + + # used to determine if text streaming is supported + num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams) + progress_bar: ProgressBar with comfy_progress(total=max_new_tokens) as progress_bar: # todo: deal with batches correctly, don't assume batch size 1 @@ -388,34 +543,38 @@ class TransformersGenerate(CustomNode): nonlocal token_count nonlocal progress_bar - # todo: this has to be more mathematically sensible - eos_token_probability = progress_logits_processor.eos_probability token_count += 1 - value = max(eos_token_probability * max_new_tokens, token_count) preview = TransformerStreamedProgress(next_token=next_token) - progress_bar.update_absolute(value, total=max_new_tokens, preview_image_or_output=preview) + progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview) text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True) with seed_for_block(seed): + if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs: + inputs.pop("attention_mask") output_ids = transformers_model.generate( **inputs, - logits_processor=LogitsProcessorList([progress_logits_processor]), - streamer=text_streamer, + streamer=text_streamer if num_beams <= 1 else None, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty if repetition_penalty != 0 else None, **generate_kwargs ) - if transformers_model.config.is_encoder_decoder: - start_position = 1 - else: + if not transformers_model.config.is_encoder_decoder: start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1] - output_ids = output_ids[:, start_position:] + output_ids = output_ids[:, start_position:] + if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original: + prev_src_lang = tokenizer.src_lang + tokenizer.src_lang = tokens_original["src_lang"] + else: + prev_src_lang = None # todo: is this redundant consider I'm decoding in the on_finalized_text block? - outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - + try: + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + finally: + if prev_src_lang is not None: + tokenizer.src_lang = prev_src_lang # gpu-loaded stuff like images can now be unloaded if hasattr(tokens, "to"): del tokens @@ -459,6 +618,10 @@ for cls in ( TransformersImageProcessorLoader, TransformersGenerate, OneShotInstructTokenize, + TransformersM2M100LanguageCodes, + TransformersTokenize, + TransformersFlores200LanguageCodes, + TransformersTranslationTokenize, PreviewString, ): NODE_CLASS_MAPPINGS[cls.__name__] = cls diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index 9d0eda246..9b848c253 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -187,7 +187,7 @@ class HashImage(CustomNode): def INPUT_TYPES(cls) -> InputTypes: return { "required": { - "images": ("IMAGE",), + "images": ("IMAGE", {}), } } @@ -270,7 +270,7 @@ class DevNullUris(CustomNode): def INPUT_TYPES(cls) -> InputTypes: return { "required": { - "images": ("IMAGE",), + "images": ("IMAGE", {}), } } @@ -332,8 +332,8 @@ class UriFormat(CustomNode): "output_dir_format_name": ("STRING", {"default": "output"}), }, "optional": { - "images": ("IMAGE",), - "image_hashes": ("IMAGE_HASHES",), + "images": ("IMAGE", {}), + "image_hashes": ("IMAGE_HASHES", {}), }, "hidden": { "prompt": "PROMPT", @@ -394,7 +394,7 @@ class ImageExifMerge(CustomNode): return { "required": {}, "optional": { - f"value{i}": ("EXIF",) for i in range(5) + f"value{i}": ("EXIF", {}) for i in range(5) } } @@ -421,7 +421,7 @@ class ImageExifCreationDateAndBatchNumber(CustomNode): def INPUT_TYPES(cls) -> InputTypes: return { "required": { - "images": ("IMAGE",), + "images": ("IMAGE", {}), } } @@ -446,7 +446,7 @@ class ImageExif(ImageExifBase, CustomNode): def INPUT_TYPES(cls) -> InputTypes: return { "required": { - "images": ("IMAGE",), + "images": ("IMAGE", {}), }, "optional": { **_common_image_metadatas @@ -463,7 +463,7 @@ class ImageExifUncommon(ImageExifBase, CustomNode): def INPUT_TYPES(cls) -> InputTypes: return { "required": { - "images": ("IMAGE",), + "images": ("IMAGE", {}), }, "optional": { **_common_image_metadatas, @@ -509,9 +509,9 @@ class SaveImagesResponse(CustomNode): "pil_save_format": ("STRING", {"default": "png"}), }, "optional": { - "exif": ("EXIF",), - "metadata_uris": ("URIS",), - "local_uris": ("URIS",), + "exif": ("EXIF", {}), + "metadata_uris": ("URIS", {}), + "local_uris": ("URIS", {}), **_open_api_common_schema, }, "hidden": { diff --git a/tests/inference/test_workflows.py b/tests/inference/test_workflows.py index 8bed2b158..6ab02f968 100644 --- a/tests/inference/test_workflows.py +++ b/tests/inference/test_workflows.py @@ -36,7 +36,7 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu: except (ImportError, ModuleNotFoundError): pytest.skip("requires torchaudio") - workflow = json.loads(workflow_file.read_text()) + workflow = json.loads(workflow_file.read_text(encoding="utf8")) prompt = Prompt.validate(workflow) # todo: add all the models we want to test a bit m2ore elegantly diff --git a/tests/inference/workflows/translation-0.json b/tests/inference/workflows/translation-0.json new file mode 100644 index 000000000..5c70f1705 --- /dev/null +++ b/tests/inference/workflows/translation-0.json @@ -0,0 +1,85 @@ +{ + "1": { + "inputs": { + "ckpt_name": "facebook/nllb-200-distilled-1.3B", + "subfolder": "" + }, + "class_type": "TransformersLoader", + "_meta": { + "title": "TransformersLoader" + } + }, + "2": { + "inputs": { + "max_new_tokens": 512, + "repetition_penalty": 1, + "seed": 1811645458, + "use_cache": true, + "__tokens": "spa_Latn ¡Hola ahí, David!", + "model": [ + "1", + 0 + ], + "tokens": [ + "9", + 0 + ] + }, + "class_type": "TransformersGenerate", + "_meta": { + "title": "TransformersGenerate" + } + }, + "5": { + "inputs": { + "value": [ + "2", + 0 + ], + "output": "¡Hola ahí, David!" + }, + "class_type": "PreviewString", + "_meta": { + "title": "PreviewString" + } + }, + "9": { + "inputs": { + "prompt": "Hello there, David!", + "src_lang": [ + "12", + 0 + ], + "tgt_lang": [ + "13", + 0 + ], + "model": [ + "1", + 0 + ] + }, + "class_type": "TransformersTranslationTokenize", + "_meta": { + "title": "TransformersTranslationTokenize" + } + }, + "12": { + "inputs": { + "lang_id": "eng_Latn" + }, + "class_type": "TransformersFlores200LanguageCodes", + "_meta": { + "title": "TransformersFlores200LanguageCodes" + } + }, + "13": { + "inputs": { + "lang_id": "spa_Latn" + }, + "class_type": "TransformersFlores200LanguageCodes", + "_meta": { + "title": "TransformersFlores200LanguageCodes" + } + } +} \ No newline at end of file diff --git a/tests/inference/workflows/translation-1.json b/tests/inference/workflows/translation-1.json new file mode 100644 index 000000000..0c50b5e44 --- /dev/null +++ b/tests/inference/workflows/translation-1.json @@ -0,0 +1,100 @@ +{ + "1": { + "inputs": { + "ckpt_name": [ + "14", + 0 + ], + "subfolder": "" + }, + "class_type": "TransformersLoader", + "_meta": { + "title": "TransformersLoader" + } + }, + "2": { + "inputs": { + "max_new_tokens": 512, + "repetition_penalty": 1, + "seed": 3541256804, + "use_cache": true, + "model": [ + "1", + 0 + ], + "tokens": [ + "9", + 0 + ] + }, + "class_type": "TransformersGenerate", + "_meta": { + "title": "TransformersGenerate" + } + }, + "5": { + "inputs": { + "value": [ + "2", + 0 + ], + "output": "en. I'm an AI." + }, + "class_type": "PreviewString", + "_meta": { + "title": "PreviewString" + } + }, + "9": { + "inputs": { + "prompt": "こんにちは。私はAIです。", + "src_lang": [ + "15", + 0 + ], + "tgt_lang": [ + "16", + 0 + ], + "model": [ + "1", + 0 + ] + }, + "class_type": "TransformersTranslationTokenize", + "_meta": { + "title": "TransformersTranslationTokenize" + } + }, + "14": { + "inputs": { + "value": "Mitsua/elan-mt-bt-ja-en", + "name": "", + "title": "", + "description": "", + "__required": true + }, + "class_type": "StringEnumRequestParameter", + "_meta": { + "title": "StringEnumRequestParameter" + } + }, + "15": { + "inputs": { + "lang_id": "ja" + }, + "class_type": "TransformersM2M100LanguageCodes", + "_meta": { + "title": "TransformersM2M100LanguageCodes" + } + }, + "16": { + "inputs": { + "lang_id": "en" + }, + "class_type": "TransformersM2M100LanguageCodes", + "_meta": { + "title": "TransformersM2M100LanguageCodes" + } + } +} \ No newline at end of file