mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-24 18:43:36 +08:00
Make more intermediate values follow the intermediate dtype. (#13051)
This commit is contained in:
parent
b67ed2a45f
commit
dcd659590f
@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
|||||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||||
|
|
||||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
samples = samples.to(comfy.model_management.intermediate_device())
|
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
samples = samples.to(comfy.model_management.intermediate_device())
|
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||||
return samples
|
return samples
|
||||||
|
|||||||
@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
|
|||||||
out, pooled = o[:2]
|
out, pooled = o[:2]
|
||||||
|
|
||||||
if pooled is not None:
|
if pooled is not None:
|
||||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
first_pooled = pooled[0:1].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype())
|
||||||
else:
|
else:
|
||||||
first_pooled = pooled
|
first_pooled = pooled
|
||||||
|
|
||||||
@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
|
|||||||
output.append(z)
|
output.append(z)
|
||||||
|
|
||||||
if (len(output) == 0):
|
if (len(output) == 0):
|
||||||
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
|
r = (out[-1:].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled)
|
||||||
else:
|
else:
|
||||||
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
|
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled)
|
||||||
|
|
||||||
if len(o) > 2:
|
if len(o) > 2:
|
||||||
extra = {}
|
extra = {}
|
||||||
for k in o[2]:
|
for k in o[2]:
|
||||||
v = o[2][k]
|
v = o[2][k]
|
||||||
if k == "attention_mask":
|
if k == "attention_mask":
|
||||||
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
|
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype())
|
||||||
extra[k] = v
|
extra[k] = v
|
||||||
|
|
||||||
r = r + (extra,)
|
r = r + (extra,)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user