Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-10 06:10:50 +08:00
Quality of life improvements

- export_custom_nodes() finds all the classes that inherit from CustomNode and exports them so that custom node discovery can find them
- regular expression nodes
- additional string formatting and parsing nodes
This commit is contained in:
parent cf08b11132
commit 31b6b53236
@@ -1,10 +1,14 @@
import hashlib

from .cli_args import args
+from .nodes.package_typing import CustomNode

from PIL import ImageFile, UnidentifiedImageError


-def conditioning_set_values(conditioning, values={}):
+def conditioning_set_values(conditioning, values: dict = None):
+    if values is None:
+        values = {}
    c = []
    for t in conditioning:
        n = [t[0], t[1].copy()]
@@ -14,11 +18,12 @@ def conditioning_set_values(conditioning, values={}):
    return c


def pillow(fn, arg):
    prev_value = None
    try:
        x = fn(arg)
-    except (OSError, UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445, also fixes ComfyUI issue #3416
+    except (OSError, UnidentifiedImageError, ValueError): # PIL issues #4472 and #2445, also fixes ComfyUI issue #3416
        prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        x = fn(arg)
@@ -27,6 +32,7 @@ def pillow(fn, arg):
        ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
    return x


def hasher():
    hashfuncs = {
        "md5": hashlib.md5,
@@ -35,3 +41,35 @@ def hasher():
        "sha512": hashlib.sha512
    }
    return hashfuncs[args.default_hashing_function]
+
+
+def export_custom_nodes():
+    """
+    Finds all classes in the current module that extend CustomNode and creates
+    a NODE_CLASS_MAPPINGS dictionary mapping class names to class objects.
+    Must be called from within the module where the CustomNode classes are defined.
+    """
+    import inspect
+
+    # Get the calling module
+    frame = inspect.currentframe()
+    try:
+        module = inspect.getmodule(frame.f_back)
+
+        custom_nodes = {}
+        for name, obj in inspect.getmembers(module):
+            if (inspect.isclass(obj) and
+                    CustomNode in obj.__mro__ and
+                    obj != CustomNode):
+                custom_nodes[name] = obj
+        if hasattr(module, 'NODE_CLASS_MAPPINGS'):
+            node_class_mappings: dict = getattr(module, 'NODE_CLASS_MAPPINGS')
+            node_class_mappings.update(custom_nodes)
+        else:
+            setattr(module, 'NODE_CLASS_MAPPINGS', custom_nodes)
+
+    finally:
+        # Clean up circular reference
+        del frame
+
+    return custom_nodes
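
For illustration, a minimal sketch of how a node module is expected to use the new helper (this is not part of the commit; HelloNode and the "examples" category are hypothetical):

from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes


class HelloNode(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {"required": {"name": ("STRING", {})}}

    CATEGORY = "examples"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, name: str = ""):
        return f"hello, {name}",


# Finds HelloNode via inspect and publishes it in this module's
# NODE_CLASS_MAPPINGS, which is what node discovery reads.
export_custom_nodes()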
@@ -2,7 +2,7 @@ from __future__ import annotations

from dataclasses import dataclass, field
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
-    Callable, List, Type
+    Callable, List, Type, runtime_checkable

from typing_extensions import TypedDict, NotRequired

@@ -64,7 +64,7 @@ NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]

# numpy seeds must be between 0 and 2**32 - 1
-Seed = ("INT", {"default": 0, "min": 0, "max": 2**32 - 1})
+Seed = ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1})
Seed64 = ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff})
SeedSpec = tuple[Literal["INT"], TypedDict("SeedSpecOptions", {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]})]
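
As a quick illustration of how the Seed spec plugs into a node's INPUT_TYPES, a hypothetical sketch (not part of this commit; the node name and category are invented, and the direct use of Seed as an input spec value is inferred from its tuple shape):

from comfy.nodes.package_typing import CustomNode, InputTypes, Seed


class NoiseExample(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                # Seed expands to ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
                # keeping the value inside numpy's valid seed range.
                "seed": Seed,
            }
        }

    CATEGORY = "examples"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, seed: int = 0):
        return seed,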
@@ -1,6 +1,9 @@
from functools import reduce
from operator import add, mul, pow
from typing import Any

from comfy.comfy_types import IO
+from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes


@@ -527,32 +530,38 @@ class FloatToInt(CustomNode):
        return int(value),


-NODE_CLASS_MAPPINGS = {}
-for cls in (
-    FloatAdd,
-    FloatSubtract,
-    FloatMultiply,
-    FloatDivide,
-    FloatPower,
-    FloatMin,
-    FloatMax,
-    FloatAbs,
-    FloatAverage,
-    FloatLerp,
-    FloatInverseLerp,
-    FloatClamp,
-    IntAdd,
-    IntSubtract,
-    IntMultiply,
-    IntDivide,
-    IntMod,
-    IntPower,
-    IntMin,
-    IntMax,
-    IntAbs,
-    IntAverage,
-    IntLerp,
-    IntInverseLerp,
-    IntClamp,
-):
-    NODE_CLASS_MAPPINGS[cls.__name__] = cls
+class StringToInt(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {}),
+            }
+        }
+
+    CATEGORY = "arithmetic"
+    RETURN_TYPES = ("INT",)
+    FUNCTION = "execute"
+
+    def execute(self, value: str = "0"):
+        return int(value),
+
+
+class StringToFloat(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {}),
+            }
+        }
+
+    CATEGORY = "arithmetic"
+    RETURN_TYPES = ("FLOAT",)
+    FUNCTION = "execute"
+
+    def execute(self, value: str = "0"):
+        return float(value),
+
+
+export_custom_nodes()
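
For reference, a short sketch of how the two new conversion nodes behave when called directly, in the style of the unit tests added later in this commit (the import path is assumed, since the diff does not name this file):

# Import path is an assumption based on the other comfy_extras node modules in this commit.
from comfy_extras.nodes.nodes_arithmetic import StringToInt, StringToFloat

# Each execute() returns a one-element tuple, matching the node's RETURN_TYPES.
value, = StringToInt().execute(value="42")
assert value == 42

value, = StringToFloat().execute(value="3.5")
assert value == 3.5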
@@ -20,6 +20,7 @@ from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWA
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device
+from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult, Seed

_AUTO_CHAT_TEMPLATE = "default"
@@ -370,7 +371,7 @@ class PreviewString(CustomNode):
            }
        }

-    CATEGORY = "language"
+    CATEGORY = "strings"
    FUNCTION = "execute"
    RETURN_TYPES = ("STRING",)
    OUTPUT_NODE = True
@@ -392,7 +393,7 @@ class SaveString(CustomNode):
            }
        }

-    CATEGORY = "language"
+    CATEGORY = "strings"
    FUNCTION = "execute"
    OUTPUT_NODE = True
    RETURN_TYPES = ()
@@ -412,24 +413,4 @@ class SaveString(CustomNode):
        return {"ui": {"string": value}}


-NODE_CLASS_MAPPINGS = {}
-for cls in (
-    TransformerTopKSampler,
-    TransformerTopPSampler,
-    TransformerTemperatureSampler,
-    TransformerGreedySampler,
-    TransformerContrastiveSearchSampler,
-    TransformerBeamSearchSampler,
-    TransformerMergeSamplers,
-    TransformersLoader,
-    TransformersImageProcessorLoader,
-    TransformersGenerate,
-    OneShotInstructTokenize,
-    TransformersM2M100LanguageCodes,
-    TransformersTokenize,
-    TransformersFlores200LanguageCodes,
-    TransformersTranslationTokenize,
-    PreviewString,
-    SaveString,
-):
-    NODE_CLASS_MAPPINGS[cls.__name__] = cls
+export_custom_nodes()
comfy_extras/nodes/nodes_regexp.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import re

from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes

MATCH_TYPE_NAME = "MATCH"


class RegexFlags(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        flags_ = {flag_name: ("BOOLEAN", {}) for flag_name in ["ASCII", "IGNORECASE", "LOCALE", "MULTILINE", "DOTALL", "VERBOSE",
                                                               "UNICODE", "NOFLAG"]}
        return {
            "required": flags_,
        }

    CATEGORY = "regular_expressions"
    FUNCTION = "execute"
    RETURN_TYPES = ("INT",)

    def execute(self, **kwargs) -> tuple[int]:
        flags = re.RegexFlag.NOFLAG

        for name, on in kwargs.items():
            if on:
                flags |= re.RegexFlag[name]

        return int(flags),


class Regex(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "pattern": ("STRING", {}),
                "string": ("STRING", {}),
            },
            "optional": {
                "flags": ("INT", {"min": 0, "default": 0})
            }
        }

    CATEGORY = "regular_expressions"
    FUNCTION = "execute"
    RETURN_TYPES = (MATCH_TYPE_NAME,)

    def execute(self, pattern: str = "", string: str = "", flags: int = 0) -> tuple[re.Match]:
        return re.match(pattern=pattern, string=string, flags=flags),


class RegexMatchGroupByIndex(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "match": (MATCH_TYPE_NAME, {}),
                "index": ("INT", {}),
            },
        }

    CATEGORY = "regular_expressions"
    FUNCTION = "execute"
    RETURN_TYPES = (MATCH_TYPE_NAME,)

    def execute(self, match: re.Match, index: int = 0) -> tuple[str]:
        return match.group(index),


class RegexMatchGroupByName(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "match": (MATCH_TYPE_NAME, {}),
                "name": ("STRING", {}),
            },
        }

    CATEGORY = "regular_expressions"
    FUNCTION = "execute"
    RETURN_TYPES = (MATCH_TYPE_NAME,)

    def execute(self, match: re.Match, name: str = "") -> tuple[str]:
        return match.group(name),


class RegexMatchExpand(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "match": (MATCH_TYPE_NAME, {}),
                "template": ("STRING", {}),
            },
        }

    CATEGORY = "regular_expressions"
    FUNCTION = "execute"
    RETURN_TYPES = (MATCH_TYPE_NAME,)

    def execute(self, match: re.Match, template: str = "") -> tuple[str]:
        return match.expand(template),


export_custom_nodes()
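
A brief sketch of chaining the new regular expression nodes, mirroring what the unit tests added below exercise (illustrative only, not part of the commit):

import re

from comfy_extras.nodes.nodes_regexp import Regex, RegexFlags, RegexMatchGroupByIndex

# Build a flags integer, then match and pull a group out of the result.
flags, = RegexFlags().execute(ASCII=False, IGNORECASE=True, LOCALE=False, MULTILINE=False,
                              DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=False)
assert flags == re.IGNORECASE

match, = Regex().execute(pattern=r"(HELLO) (\w+)", string="hello world", flags=flags)
group, = RegexMatchGroupByIndex().execute(match=match, index=2)
assert group == "world"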
comfy_extras/nodes/nodes_strings.py (new file, 67 lines)
@@ -0,0 +1,67 @@
from typing import Any

import torch
from natsort import natsorted

from comfy.comfy_types import IO
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult


def format_value(value: Any) -> str | Any:
    """Helper function to format values for string formatting."""
    if value is None:
        return "None"
    elif isinstance(value, torch.Tensor):
        if value.numel() > 10:  # For large tensors
            shape_str = 'x'.join(str(x) for x in value.shape)
            return f"<Tensor shape={shape_str}>"
        else:
            return str(value.tolist())
    return value


class StringFormat(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        optional = {f"value{i}": (IO.ANY, {"default": "", "forceInput": True}) for i in range(5)}
        optional["format"] = (IO.STRING, {"default": "{}", "multiline": True})
        return {
            "required": {},
            "optional": optional
        }

    RETURN_TYPES = ("STRING",)
    CATEGORY = "strings"
    FUNCTION = "execute"

    def execute(self, format: str = "{}", *args: Any, **kwargs) -> ValidatedNodeResult:
        for k, v in kwargs.items():
            kwargs[k] = format_value(v)
        try:
            if any(f'{{{k}' in format for k in kwargs.keys()):
                return (format.format(**kwargs),)
            else:
                return (format.format(*[kwargs[k] for k in natsorted(kwargs.keys())]),)
        except (IndexError, KeyError) as e:
            return ("Format error: " + str(e),)


class ToString(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": (IO.ANY, {}),
            }
        }

    CATEGORY = "strings"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, value: Any):
        return str(value),


export_custom_nodes()
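
A short sketch of the two formatting paths StringFormat supports, positional and named, plus ToString (illustrative only, mirroring the tests added below):

from comfy_extras.nodes.nodes_strings import StringFormat, ToString

n = StringFormat()

# Positional placeholders: value0..value4 are fed to format() in natural-sort order.
result, = n.execute(format="{} + {} = {}", value0=1, value1=2, value2=3)
assert result == "1 + 2 = 3"

# Named placeholders: if the format string references the kwargs by name,
# format() is called with keyword arguments instead.
result, = n.execute(format="x={value0}, y={value1}", value0=10, value1=20)
assert result == "x=10, y=20"

# ToString coerces any input to its string representation.
text, = ToString().execute(value=3.5)
assert text == "3.5"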
tests/unit/test_regular_expression_nodes.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import re

import pytest

from comfy_extras.nodes.nodes_regexp import RegexFlags, Regex, RegexMatchGroupByIndex, RegexMatchGroupByName, \
    RegexMatchExpand


def test_regex_flags():
    n = RegexFlags()

    # Test with no flags
    flags, = n.execute(ASCII=False, IGNORECASE=False, LOCALE=False, MULTILINE=False,
                       DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=True)
    assert flags == 0

    # Test single flag
    flags, = n.execute(ASCII=True, IGNORECASE=False, LOCALE=False, MULTILINE=False,
                       DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=False)
    assert flags == re.ASCII

    # Test multiple flags
    flags, = n.execute(ASCII=True, IGNORECASE=True, LOCALE=False, MULTILINE=False,
                       DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=False)
    assert flags == (re.ASCII | re.IGNORECASE)


def test_regex():
    n = Regex()

    # Basic match test
    match, = n.execute(pattern=r"hello", string="hello world")
    assert match is not None
    assert match.group(0) == "hello"

    # Test with flags
    match, = n.execute(pattern=r"HELLO", string="hello world", flags=re.IGNORECASE)
    assert match is not None
    assert match.group(0) == "hello"

    # Test no match
    match, = n.execute(pattern=r"python", string="hello world")
    assert match is None


def test_regex_match_group_by_index():
    n = RegexMatchGroupByIndex()
    regex = Regex()

    # Test basic group
    match, = regex.execute(pattern=r"(hello) (world)", string="hello world")
    group, = n.execute(match=match, index=0)
    assert group == "hello world"

    group, = n.execute(match=match, index=1)
    assert group == "hello"

    group, = n.execute(match=match, index=2)
    assert group == "world"


def test_regex_match_group_by_name():
    n = RegexMatchGroupByName()
    regex = Regex()

    # Test named group
    match, = regex.execute(pattern=r"(?P<greeting>hello) (?P<subject>world)",
                           string="hello world")

    group, = n.execute(match=match, name="greeting")
    assert group == "hello"

    group, = n.execute(match=match, name="subject")
    assert group == "world"

    # Test non-existent group name
    with pytest.raises(IndexError):
        n.execute(match=match, name="nonexistent")


def test_regex_match_expand():
    n = RegexMatchExpand()
    regex = Regex()

    # Test basic expansion
    match, = regex.execute(pattern=r"(hello) (world)", string="hello world")
    result, = n.execute(match=match, template=r"\2, \1!")
    assert result == "world, hello!"

    # Test named group expansion
    match, = regex.execute(pattern=r"(?P<greeting>hello) (?P<subject>world)",
                           string="hello world")
    result, = n.execute(match=match, template=r"\g<subject>, \g<greeting>!")
    assert result == "world, hello!"
tests/unit/test_string_nodes.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import torch

from comfy_extras.nodes.nodes_strings import StringFormat


def test_string_format_basic():
    n = StringFormat()

    # Test basic string formatting
    result, = n.execute(format="Hello, {}!", value0="World")
    assert result == "Hello, World!"

    # Test multiple values
    result, = n.execute(format="{} plus {} equals {}", value0=2, value1=2, value2=4)
    assert result == "2 plus 2 equals 4"


def test_string_format_types():
    n = StringFormat()

    # Test with different types
    result, = n.execute(format="Float: {:.2f}, Int: {}, Bool: {}",
                        value0=3.14159, value1=42, value2=True)
    assert result == "Float: 3.14, Int: 42, Bool: True"

    # Test None values
    result, = n.execute(format="{}, {}, {}", value0=None, value1="test", value2=None)
    assert result == "None, test, None"


def test_string_format_tensors():
    n = StringFormat()

    # Test small tensor
    small_tensor = torch.tensor([1, 2, 3])
    result, = n.execute(format="Tensor: {}", value0=small_tensor)
    assert result == "Tensor: [1, 2, 3]"

    # Test large tensor
    large_tensor = torch.randn(100, 100)
    result, = n.execute(format="Large tensor: {}", value0=large_tensor)
    assert result == "Large tensor: <Tensor shape=100x100>"

    # Test mixed tensor sizes
    small_tensor = torch.tensor([1, 2])
    large_tensor = torch.randn(50, 50)
    result, = n.execute(format="{} and {}", value0=small_tensor, value1=large_tensor)
    assert "and <Tensor shape=50x50>" in result
    assert "[1, 2]" in result


def test_string_format_edge_cases():
    n = StringFormat()

    # Test with missing values
    result, = n.execute(format="{} {} {}", value0="a", value1="b")
    assert result.startswith("Format error: ")

    # Test with empty format string
    result, = n.execute(format="", value0="ignored")
    assert result == ""

    # Test with no placeholders
    result, = n.execute(format="Hello World", value0="ignored")
    assert result == "Hello World"

    # Test with named placeholders
    result, = n.execute(format="X: {value0}, Y: {value1}", value0=10, value1=20)
    assert result == "X: 10, Y: 20"

    # Test mixing None, tensors and regular values
    tensor = torch.tensor([1, 2, 3])
    result, = n.execute(format="{}, {}, {}", value0=None, value1=tensor, value2="test")
    assert result == "None, [1, 2, 3], test"


def test_string_format_tensor_edge_cases():
    n = StringFormat()

    # Test empty tensor
    empty_tensor = torch.tensor([])
    result, = n.execute(format="Empty tensor: {}", value0=empty_tensor)
    assert result == "Empty tensor: []"

    # Test scalar tensor
    scalar_tensor = torch.tensor(5)
    result, = n.execute(format="Scalar tensor: {}", value0=scalar_tensor)
    assert result == "Scalar tensor: 5"

    # Test multi-dimensional small tensor
    small_2d_tensor = torch.tensor([[1, 2], [3, 4]])
    result, = n.execute(format="2D tensor: {}", value0=small_2d_tensor)
    assert result == "2D tensor: [[1, 2], [3, 4]]"