use torch.load/save and fix bad behaviors

This commit is contained in:
Kohaku-Blueleaf 2025-11-22 16:29:55 +08:00
parent 71ac74ecb8
commit 09afeeb1c4

View File

@ -1,7 +1,7 @@
import logging import logging
import os import os
import pickle
import math import math
import json
import numpy as np import numpy as np
import torch import torch
@ -383,7 +383,7 @@ class ImageProcessingNode(io.ComfyNode):
is_group = cls._detect_processing_mode() is_group = cls._detect_processing_mode()
# Auto-detect is_output_list if not explicitly set # Auto-detect is_output_list if not explicitly set
# Single processing: True (backend collects results into list) # Single processing: False (backend collects results into list)
# Group processing: True by default (can be False for single-output nodes) # Group processing: True by default (can be False for single-output nodes)
output_is_list = ( output_is_list = (
cls.is_output_list if cls.is_output_list is not None else is_group cls.is_output_list if cls.is_output_list is not None else is_group
@ -435,13 +435,7 @@ class ImageProcessingNode(io.ComfyNode):
# Individual processing: images is single item, call _process # Individual processing: images is single item, call _process
result = cls._process(images, **params) result = cls._process(images, **params)
# Wrap result based on is_output_list return io.NodeOutput(result)
if cls.is_output_list:
# Result should already be a list (or will be for individual)
return io.NodeOutput(result if is_group else [result])
else:
# Single output - wrap in list for NodeOutput
return io.NodeOutput([result])
@classmethod @classmethod
def _process(cls, image, **kwargs): def _process(cls, image, **kwargs):
@ -1395,7 +1389,7 @@ class SaveTrainingDataset(io.ComfyNode):
shard_path = os.path.join(output_dir, shard_filename) shard_path = os.path.join(output_dir, shard_filename)
with open(shard_path, "wb") as f: with open(shard_path, "wb") as f:
pickle.dump(shard_data, f, protocol=pickle.HIGHEST_PROTOCOL) torch.save(shard_data, f)
logging.info( logging.info(
f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)" f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
@ -1409,8 +1403,6 @@ class SaveTrainingDataset(io.ComfyNode):
} }
metadata_path = os.path.join(output_dir, "metadata.json") metadata_path = os.path.join(output_dir, "metadata.json")
with open(metadata_path, "w") as f: with open(metadata_path, "w") as f:
import json
json.dump(metadata, f, indent=2) json.dump(metadata, f, indent=2)
logging.info(f"Successfully saved {num_samples} samples to {output_dir}.") logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
@ -1478,7 +1470,7 @@ class LoadTrainingDataset(io.ComfyNode):
shard_path = os.path.join(dataset_dir, shard_file) shard_path = os.path.join(dataset_dir, shard_file)
with open(shard_path, "rb") as f: with open(shard_path, "rb") as f:
shard_data = pickle.load(f) shard_data = torch.load(f)
all_latents.extend(shard_data["latents"]) all_latents.extend(shard_data["latents"])
all_conditioning.extend(shard_data["conditioning"]) all_conditioning.extend(shard_data["conditioning"])