
Commit c0ca047

Add the function specialization for promote with ITensorListRef (#1230)
* add the function specialization for `promote` with `ITensorListRef`
* update autocast doc
* add comments for `test_cat_promote`
1 parent: cf7859f
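For orientation, the user-visible behavior this commit targets, mirroring the new `test_cat_promote` test below: under CPU autocast, `torch.cat` uses the promote cast policy, so mixed `float32`/`bfloat16` inputs should be cast to the widest dtype. A minimal sketch, assuming `intel_extension_for_pytorch` is installed so its CPU autocast rules are registered:

```python
import torch
import intel_extension_for_pytorch  # noqa: F401  # assumed installed; registers the CPU autocast rules

a = torch.rand(24, 128, 128)                        # float32
b = torch.rand(24, 128, 128, dtype=torch.bfloat16)  # bfloat16

# cat uses the "promote" cast policy: mixed-dtype inputs are cast to the
# widest dtype before the call, so the result should be float32.
with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
    c = torch.cat([a, b], 0)

print(c.dtype)  # expected: torch.float32
```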

File tree

4 files changed (+50, -3 lines)


csrc/cpu/autocast/autocast_mode.cpp

Lines changed: 1 addition & 2 deletions

@@ -127,8 +127,7 @@ struct CPU_WrapFunction_<
       return (*F)(cpu_cached_cast(at::kFloat, args)...);
     case DtypeCastPolicy::promote:
       return (*F)(cpu_cached_cast(
-          promote_type(get_autocast_dtype(), DeviceType::CPU, args...),
-          args)...);
+          promote_type(get_autocast_dtype(), args...), args)...);
     default:
       return (*F)(args...);
   }
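For readers unfamiliar with the cast policies: `DtypeCastPolicy::promote` computes the widest dtype among the op's floating-point tensor arguments (starting from the autocast dtype) and casts every argument to it before dispatching. A rough Python sketch of that flow, using hypothetical helper names that are not part of the extension's API:

```python
import torch

def widest_dtype(autocast_dtype, *tensors):
    # Hypothetical analog of promote_type/prioritize: start from the autocast
    # dtype (e.g. bfloat16) and widen to float32 if any floating-point
    # argument is already float32.
    dtype = autocast_dtype
    for t in tensors:
        if t.is_floating_point() and t.dtype == torch.float32:
            dtype = torch.float32
    return dtype

def call_with_promote(op, *tensors, autocast_dtype=torch.bfloat16):
    # Cast every floating-point argument to the widest dtype, then call the op.
    dtype = widest_dtype(autocast_dtype, *tensors)
    return op(*(t.to(dtype) if t.is_floating_point() else t for t in tensors))

out = call_with_promote(torch.add, torch.rand(4), torch.rand(4, dtype=torch.bfloat16))
print(out.dtype)  # torch.float32
```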

csrc/cpu/autocast/autocast_mode.h

Lines changed: 20 additions & 0 deletions

@@ -54,6 +54,17 @@ inline c10::optional<Tensor> cpu_cached_cast(
   }
 }
 
+inline std::vector<Tensor> cpu_cached_cast(
+    at::ScalarType to_type,
+    const at::ITensorListRef& arg) {
+  std::vector<Tensor> vec;
+  vec.reserve(arg.size());
+  for (const auto& t : arg) {
+    vec.push_back(cpu_cached_cast(to_type, t));
+  }
+  return vec;
+}
+
 inline std::vector<Tensor> cpu_cached_cast(
     at::ScalarType to_type,
     const TensorList& arg) {
@@ -137,6 +148,15 @@ inline at::ScalarType prioritize(
   return current;
 }
 
+inline at::ScalarType prioritize(
+    at::ScalarType current,
+    const at::ITensorListRef& list) {
+  for (const auto& tensor : list) {
+    current = prioritize(current, tensor);
+  }
+  return current;
+}
+
 // Template to catch non-Tensor args (no-op that returns current best guess)
 template <typename T>
 inline at::ScalarType prioritize(at::ScalarType current, T nextArg) {
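The two new overloads extend the existing Tensor/TensorList handling to `ITensorListRef`: one casts every tensor in the list, the other folds `prioritize` over the list to find the widest dtype. A Python analog, illustrative only (the actual implementation is the C++ above, and these helper names are made up):

```python
import torch

def cast_tensor_list(to_dtype, tensors):
    # Analog of the new cpu_cached_cast(ScalarType, ITensorListRef) overload:
    # cast every tensor in the list and return the results as a new list.
    return [t.to(to_dtype) if t.is_floating_point() else t for t in tensors]

def prioritize_tensor_list(current, tensors):
    # Analog of the new prioritize(ScalarType, ITensorListRef) overload:
    # widen the running dtype with each tensor in the list.
    for t in tensors:
        if t.is_floating_point() and t.dtype == torch.float32:
            current = torch.float32
    return current

inputs = [torch.rand(2, 2), torch.rand(2, 2, dtype=torch.bfloat16)]
dtype = prioritize_tensor_list(torch.bfloat16, inputs)      # -> torch.float32
print([t.dtype for t in cast_tensor_list(dtype, inputs)])   # [torch.float32, torch.float32]
```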

docs/tutorials/features/amp.md

Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@ If an op is unlisted, we assume it's numerically stable in `bfloat16`. If you be
 
 #### Ops that can autocast to `float32`
 
-`conv_transpose1d`, `conv_transpose2d`, `conv_transpose3d`, `mish`, `avg_pool3d`, `binary_cross_entropy`, `grid_sampler`, `polar`, `prod`, `quantile`, `nanquantile`, `stft`, `cdist`, `trace`, `view_as_complex`, `cholesky`, `cholesky_inverse`, `cholesky_solve`, `inverse`, `lu_solve`, `orgqr`, `ormqr`, `pinverse`, `max_unpool2d`, `max_unpool3d`, `adaptive_avg_pool3d`, `reflection_pad1d`, `reflection_pad2d`, `replication_pad1d`, `replication_pad2d`, `replication_pad3d`, `mse_loss`, `cosine_embedding_loss`, `nll_loss`, `nll_loss2d`, `hinge_embedding_loss`, `poisson_nll_loss`, `smooth_l1_loss`, `cross_entropy_loss`, `l1_loss`, `huber_loss`, `margin_ranking_loss`, `soft_margin_loss`, `triplet_margin_loss`, `multi_margin_loss`, `ctc_loss`, `kl_div`, `multilabel_margin_loss`, `binary_cross_entropy_with_logits`, `fft_fft`, `fft_ifft`, `fft_fft2`, `fft_ifft2`, `fft_fftn`, `fft_ifftn`, `fft_rfft`, `fft_irfft`, `fft_rfft2`, `fft_irfft2`, `fft_rfftn`, `fft_irfftn`, `fft_hfft`, `fft_ihfft`, `linalg_matrix_norm`, `linalg_cond`, `linalg_matrix_rank`, `linalg_solve`, `linalg_cholesky`, `linalg_svdvals`, `linalg_eigvals`, `linalg_eigvalsh`, `linalg_inv`, `linalg_householder_product`, `linalg_tensorinv`, `linalg_tensorsolve`, `fake_quantize_per_tensor_affine`, `geqrf`, `_lu_with_info`, `qr`, `svd`, `symeig`, `triangular_solve`, `fractional_max_pool2d`, `fractional_max_pool3d`, `adaptive_max_pool3d`, `multilabel_margin_loss_forward`, `linalg_qr`, `linalg_cholesky_ex`, `linalg_svd`, `linalg_eig`, `linalg_eigh`, `linalg_lstsq`, `linalg_inv_ex`
+`conv_transpose1d`, `conv_transpose2d`, `conv_transpose3d`, `mish`, `avg_pool3d`, `max_pool3d`, `binary_cross_entropy`, `grid_sampler`, `polar`, `prod`, `quantile`, `nanquantile`, `stft`, `cdist`, `trace`, `view_as_complex`, `cholesky`, `cholesky_inverse`, `cholesky_solve`, `inverse`, `lu_solve`, `orgqr`, `ormqr`, `pinverse`, `max_unpool2d`, `max_unpool3d`, `adaptive_avg_pool3d`, `reflection_pad1d`, `reflection_pad2d`, `replication_pad1d`, `replication_pad2d`, `replication_pad3d`, `mse_loss`, `cosine_embedding_loss`, `nll_loss`, `nll_loss2d`, `hinge_embedding_loss`, `poisson_nll_loss`, `smooth_l1_loss`, `cross_entropy_loss`, `l1_loss`, `huber_loss`, `margin_ranking_loss`, `soft_margin_loss`, `triplet_margin_loss`, `multi_margin_loss`, `ctc_loss`, `kl_div`, `multilabel_margin_loss`, `binary_cross_entropy_with_logits`, `fft_fft`, `fft_ifft`, `fft_fft2`, `fft_ifft2`, `fft_fftn`, `fft_ifftn`, `fft_rfft`, `fft_irfft`, `fft_rfft2`, `fft_irfft2`, `fft_rfftn`, `fft_irfftn`, `fft_hfft`, `fft_ihfft`, `linalg_matrix_norm`, `linalg_cond`, `linalg_matrix_rank`, `linalg_solve`, `linalg_cholesky`, `linalg_svdvals`, `linalg_eigvals`, `linalg_eigvalsh`, `linalg_inv`, `linalg_householder_product`, `linalg_tensorinv`, `linalg_tensorsolve`, `fake_quantize_per_tensor_affine`, `geqrf`, `_lu_with_info`, `qr`, `svd`, `symeig`, `triangular_solve`, `fractional_max_pool2d`, `fractional_max_pool3d`, `adaptive_max_pool3d`, `multilabel_margin_loss_forward`, `linalg_qr`, `linalg_cholesky_ex`, `linalg_svd`, `linalg_eig`, `linalg_eigh`, `linalg_lstsq`, `linalg_inv_ex`
 
 #### Ops that promote to the widest input type
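As a quick illustration of the fp32 list above (a sketch, assuming the documented IPEX CPU autocast policy and that `intel_extension_for_pytorch` is installed): `mse_loss` is in the list, so even `bfloat16` inputs should yield a `float32` result inside the autocast region.

```python
import torch
import intel_extension_for_pytorch  # noqa: F401  # assumed installed

x = torch.rand(8, 8, dtype=torch.bfloat16)
y = torch.rand(8, 8, dtype=torch.bfloat16)

# mse_loss is in the "autocast to float32" list, so its inputs are cast up
# to float32 inside the autocast region.
with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    loss = torch.nn.functional.mse_loss(x, y)

print(loss.dtype)  # expected: torch.float32 per the list above
```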

tests/cpu/test_autocast.py

Lines changed: 28 additions & 0 deletions

@@ -187,6 +187,34 @@ def test_nhwc_autocast_jit_trace_model(model, x):
                 continue
             test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
 
+    # Check whether cat has done the promotion in AMP with mixed dtype inputs
+    # since input type of cat is changed to ITensorListRef
+    def test_cat_promote(self):
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super(TestModel, self).__init__()
+
+            def forward(self, a, b):
+                return torch.cat([a, b], 0)
+        with torch.jit.fuser("none"):
+            # In this testcase, we will check whether cat has done the promotion in AMP with mixed dtype inputs.
+            # To avoid the fusion group from TE, we will disable the fuser here.
+            for jit_freeze_or_not in [False, True]:
+                test_model = TestModel().eval()
+                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
+                    a = torch.rand(24, 128, 128)
+                    b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
+                    c = test_model(a, b)
+                    traced = torch.jit.trace(test_model, (a, b))
+                if jit_freeze_or_not:
+                    traced = torch.jit.freeze(traced)
+                for _ in range(3):
+                    c2 = traced(a, b)
+                self.assertTrue(c.dtype, torch.float32)
+                self.assertTrue(c2.dtype, torch.float32)
+                traced_graph = traced.graph_for(a, b)
+                self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))
+
 class TestPyTorchOps(TestCase):
     def test_bernoulli(self):
         input = torch.rand(8, 8)
