@@ -1674,18 +1674,18 @@ def _helion_fused_linear_jsd_kernel(student_logits, teacher_logits, loss, temper
16741674 # src[fused_linear_jsd.py:N]: teacher_div = torch.nn.functional.kl_div(
16751675 # src[fused_linear_jsd.py:N]: torch.log(m), teacher_prob, reduction="none", log_target=True
16761676 # src[fused_linear_jsd.py:N]: ).sum(dim=-1)
1677- v_17 = teacher_prob_1 - v_16
1678- v_18 = libdevice.exp( teacher_prob_1)
1679- v_19 = v_18 * v_17
1677+ v_17 = libdevice.exp( teacher_prob_1)
1678+ v_18 = teacher_prob_1 - v_16
1679+ v_19 = v_17 * v_18
16801680 teacher_div = tl.cast(tl.sum(v_19, 1), tl.float32)
16811681 # src[fused_linear_jsd.py:N]: torch.log(m), student_prob, reduction="none", log_target=True
16821682 v_20 = tl_math.log(v_15)
16831683 # src[fused_linear_jsd.py:N]: student_div = torch.nn.functional.kl_div(
16841684 # src[fused_linear_jsd.py:N]: torch.log(m), student_prob, reduction="none", log_target=True
16851685 # src[fused_linear_jsd.py:N]: ).sum(dim=-1)
1686- v_21 = student_prob_1 - v_20
1687- v_22 = libdevice.exp( student_prob_1)
1688- v_23 = v_22 * v_21
1686+ v_21 = libdevice.exp( student_prob_1)
1687+ v_22 = student_prob_1 - v_20
1688+ v_23 = v_21 * v_22
16891689 student_div = tl.cast(tl.sum(v_23, 1), tl.float32)
16901690 # src[fused_linear_jsd.py:N]: batch_loss = student_div + beta * (teacher_div - student_div)
16911691 v_24 = teacher_div - student_div
0 commit comments