mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
use torch.load/save and fix bad behaviors
This commit is contained in:
parent
71ac74ecb8
commit
09afeeb1c4
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import math
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -383,7 +383,7 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
is_group = cls._detect_processing_mode()
|
||||
|
||||
# Auto-detect is_output_list if not explicitly set
|
||||
# Single processing: True (backend collects results into list)
|
||||
# Single processing: False (backend collects results into list)
|
||||
# Group processing: True by default (can be False for single-output nodes)
|
||||
output_is_list = (
|
||||
cls.is_output_list if cls.is_output_list is not None else is_group
|
||||
@ -435,13 +435,7 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
# Individual processing: images is single item, call _process
|
||||
result = cls._process(images, **params)
|
||||
|
||||
# Wrap result based on is_output_list
|
||||
if cls.is_output_list:
|
||||
# Result should already be a list (or will be for individual)
|
||||
return io.NodeOutput(result if is_group else [result])
|
||||
else:
|
||||
# Single output - wrap in list for NodeOutput
|
||||
return io.NodeOutput([result])
|
||||
return io.NodeOutput(result)
|
||||
|
||||
@classmethod
|
||||
def _process(cls, image, **kwargs):
|
||||
@ -1395,7 +1389,7 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
shard_path = os.path.join(output_dir, shard_filename)
|
||||
|
||||
with open(shard_path, "wb") as f:
|
||||
pickle.dump(shard_data, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
torch.save(shard_data, f)
|
||||
|
||||
logging.info(
|
||||
f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
|
||||
@ -1409,8 +1403,6 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
}
|
||||
metadata_path = os.path.join(output_dir, "metadata.json")
|
||||
with open(metadata_path, "w") as f:
|
||||
import json
|
||||
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
|
||||
@ -1478,7 +1470,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
shard_path = os.path.join(dataset_dir, shard_file)
|
||||
|
||||
with open(shard_path, "rb") as f:
|
||||
shard_data = pickle.load(f)
|
||||
shard_data = torch.load(f)
|
||||
|
||||
all_latents.extend(shard_data["latents"])
|
||||
all_conditioning.extend(shard_data["conditioning"])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user