
Commit a4ba9af

change amp policy of layer_norm to fall through (#1703)
1 parent a9c2646 commit a4ba9af

File tree

csrc/aten/amp/autocast_mode.cpp
docs/tutorials/features/amp.md

2 files changed: +1 -33 lines changed


csrc/aten/amp/autocast_mode.cpp

Lines changed: 0 additions & 32 deletions
@@ -595,17 +595,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
       "softplus",
       Tensor(const Tensor&, const Scalar&, const Scalar&),
       fp32)
-  KERNEL_XPU(
-      ADD_NS(layer_norm),
-      "layer_norm",
-      Tensor(
-          const Tensor&,
-          IntArrayRef,
-          const c10::optional<Tensor>&,
-          const c10::optional<Tensor>&,
-          double,
-          bool),
-      fp32)
   KERNEL_XPU(
       ADD_NS(group_norm),
       "group_norm",
@@ -848,27 +837,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
       int64_t,
       c10::optional<c10::string_view>),
       fp32)
-  // The macro doesn't like these (I think it chokes on commas inside <>) so
-  // write them manually
-  m.impl(
-      TORCH_SELECTIVE_NAME("aten::native_layer_norm"),
-      TORCH_FN((&WrapFunction<
-          CastPolicy::fp32,
-          DeviceType::XPU,
-          std::tuple<Tensor, Tensor, Tensor>(
-              const Tensor&,
-              IntArrayRef,
-              const c10::optional<Tensor>&,
-              const c10::optional<Tensor>&,
-              double),
-          std::tuple<Tensor, Tensor, Tensor>(
-              const Tensor&,
-              IntArrayRef,
-              const c10::optional<Tensor>&,
-              const c10::optional<Tensor>&,
-              double),
-          &ADD_NS(native_layer_norm)>::type::call)));
-
   // promote
   KERNEL_XPU(ADD_NS(cat), "cat", Tensor(TensorList, int64_t), promote)
   KERNEL_XPU(ADD_NS(stack), "stack", Tensor(TensorList, int64_t), promote)
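
With these registrations removed, `layer_norm` and `native_layer_norm` no longer carry an explicit fp32 cast policy on XPU: inside an autocast region they fall through and run in whatever dtype their inputs already have, instead of being forced to `float32`. A minimal sketch of the observable effect, assuming an available XPU device and the `torch.xpu.amp.autocast` context manager shipped with intel_extension_for_pytorch (the snippet is illustrative and not part of this commit):

import torch
import intel_extension_for_pytorch  # noqa: F401  (registers the "xpu" device)

# LayerNorm module and input both created in bfloat16 on the XPU device.
ln = torch.nn.LayerNorm(16, device="xpu", dtype=torch.bfloat16)
x = torch.randn(8, 16, device="xpu", dtype=torch.bfloat16)

with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    y = ln(x)

# Under the old fp32 policy, autocast cast layer_norm's arguments to float32
# and y came back as float32; with the fall-through policy the op runs in the
# input dtype, so y is expected to stay bfloat16.
print(y.dtype)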

docs/tutorials/features/amp.md

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ If an op is unlisted, we assume it's numerically stable in `bfloat16` or `float16`
 
 #### Ops that can autocast to `float32`
 
-`binary_cross_entropy`, `binary_cross_entropy_with_logits`, `log_softmax`, `nll_loss`, `nll_loss2d`, `nll_loss_nd`, `cross_entropy_loss`, `fft_fft`, `fft_ifft`, `fft_fft2`, `fft_ifft2`, `fft_fftn`, `fft_ifftn`, `fft_rfft`, `fft_irfft`, `fft_rfft2`, `fft_irfft2`, `fft_rfftn`, `fft_irfftn`, `fft_hfft`, `fft_ihfft`, `acos`, `asin`, `cosh`, `erfinv`, `exp`, `expm1`, `log`, `log10`, `log2`, `log1p`, `reciprocal`, `rsqrt`, `sinh`, `tan`, `pow`, `softplus`, `layer_norm`, `group_norm`, `frobenius_norm`, `nuclear_norm`, `cosine_similarity`, `poisson_nll_loss`, `cosine_embedding_loss`, `hinge_embedding_loss`, `kl_div`, `l1_loss`, `smooth_l1_loss `, `huber_loss`, `mse_loss`, `margin_ranking_loss`, `multilabel_margin_loss`, `soft_margin_loss`, `triplet_margin_loss`, `multi_margin_loss`, `dist`, `pdist`, `cdist`, `renorm`, `native_layer_norm`
+`binary_cross_entropy`, `binary_cross_entropy_with_logits`, `log_softmax`, `nll_loss`, `nll_loss2d`, `nll_loss_nd`, `cross_entropy_loss`, `fft_fft`, `fft_ifft`, `fft_fft2`, `fft_ifft2`, `fft_fftn`, `fft_ifftn`, `fft_rfft`, `fft_irfft`, `fft_rfft2`, `fft_irfft2`, `fft_rfftn`, `fft_irfftn`, `fft_hfft`, `fft_ihfft`, `acos`, `asin`, `cosh`, `erfinv`, `exp`, `expm1`, `log`, `log10`, `log2`, `log1p`, `reciprocal`, `rsqrt`, `sinh`, `tan`, `pow`, `softplus`, `group_norm`, `frobenius_norm`, `nuclear_norm`, `cosine_similarity`, `poisson_nll_loss`, `cosine_embedding_loss`, `hinge_embedding_loss`, `kl_div`, `l1_loss`, `smooth_l1_loss `, `huber_loss`, `mse_loss`, `margin_ranking_loss`, `multilabel_margin_loss`, `soft_margin_loss`, `triplet_margin_loss`, `multi_margin_loss`, `dist`, `pdist`, `cdist`, `renorm`
 
 #### Ops that promote to the widest input type
 