mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 19:42:34 +08:00
training context var
This commit is contained in:
parent
3593628f2d
commit
ec61c02bf6
@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
|||||||
return out.to(dtype=torch.float32, device=pos.device)
|
return out.to(dtype=torch.float32, device=pos.device)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||||
|
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||||
|
|
||||||
|
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||||
|
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||||
|
|
||||||
|
return x_out.reshape(*x.shape).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||||
|
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
q_apply_rope = comfy.quant_ops.ck.apply_rope
|
||||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||||
|
def apply_rope(xq, xk, freqs_cis):
|
||||||
|
if comfy.model_management.in_training:
|
||||||
|
return _apply_rope(xq, xk, freqs_cis)
|
||||||
|
else:
|
||||||
|
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||||
|
def apply_rope1(x, freqs_cis):
|
||||||
|
if comfy.model_management.in_training:
|
||||||
|
return _apply_rope1(x, freqs_cis)
|
||||||
|
else:
|
||||||
|
return q_apply_rope1(x, freqs_cis)
|
||||||
except:
|
except:
|
||||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
apply_rope = _apply_rope
|
||||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
apply_rope1 = _apply_rope1
|
||||||
|
|
||||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
|
||||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
|
||||||
|
|
||||||
return x_out.reshape(*x.shape).type_as(x)
|
|
||||||
|
|
||||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
|
||||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
|
||||||
|
|||||||
@ -47,6 +47,11 @@ cpu_state = CPUState.GPU
|
|||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
|
|
||||||
|
|
||||||
|
# Training Related State
|
||||||
|
in_training = False
|
||||||
|
|
||||||
|
|
||||||
def get_supported_float8_types():
|
def get_supported_float8_types():
|
||||||
float8_types = []
|
float8_types = []
|
||||||
try:
|
try:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user