mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
fix nodes_train
This commit is contained in:
parent
488bb3a23f
commit
7b20e224f9
@ -15,12 +15,12 @@ import comfy.samplers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy_extras.nodes_custom_sampler
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy import node_helpers
|
||||
from comfy.cli_args import args
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy.weight_adapter import adapters
|
||||
from .nodes import nodes_custom_sampler
|
||||
|
||||
|
||||
class TrainSampler(comfy.samplers.Sampler):
|
||||
@ -484,9 +484,9 @@ class TrainLoraNode:
|
||||
train_sampler = TrainSampler(
|
||||
criterion, optimizer, loss_callback=loss_callback
|
||||
)
|
||||
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||
guider = nodes_custom_sampler.Guider_Basic(mp)
|
||||
guider.set_conds(positive) # Set conditioning from input
|
||||
ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
|
||||
ss = nodes_custom_sampler.SamplerCustomAdvanced()
|
||||
|
||||
# yoland: this currently resize to the first image in the dataset
|
||||
|
||||
@ -500,7 +500,7 @@ class TrainLoraNode:
|
||||
)
|
||||
sigma = torch.tensor([sigma])
|
||||
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
|
||||
noise = nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
|
||||
|
||||
indices = torch.randperm(num_images)[:batch_size]
|
||||
ss.sample(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user