1818from torch ._environment import is_fbcode
1919
2020from .. import exc
21+ from .._compat import supports_amd_cdna_tunables
2122from ..autotuner .effort_profile import AutotuneEffort
2223from ..autotuner .effort_profile import get_effort_profile
2324from .ref_mode import RefMode
@@ -264,6 +265,23 @@ def _get_ref_mode() -> RefMode:
264265 return RefMode .EAGER if interpret else RefMode .OFF
265266
266267
268+ def _get_dot_precision () -> DotPrecision :
269+ """
270+ Get the dot precision setting from TRITON_F32_DEFAULT environment variable.
271+ Defaults to 'tf32', 'ieee' if rocm and not CDNA.
272+ """
273+ if torch .version .hip is not None :
274+ default_precision = "tf32" if supports_amd_cdna_tunables () else "ieee"
275+ else :
276+ default_precision = "tf32"
277+
278+ return _env_get_literal (
279+ "TRITON_F32_DEFAULT" ,
280+ cast ("DotPrecision" , default_precision ),
281+ mapping = {k : k for k in ("tf32" , "tf32x3" , "ieee" )},
282+ )
283+
284+
267285@dataclasses .dataclass
268286class _Settings :
269287 # see __slots__ below for the doc strings that show up in help(Settings)
@@ -273,14 +291,7 @@ class _Settings:
273291 index_dtype : torch .dtype | None = dataclasses .field (
274292 default_factory = _get_index_dtype
275293 )
276- dot_precision : DotPrecision = dataclasses .field (
277- default_factory = functools .partial (
278- _env_get_literal ,
279- "TRITON_F32_DEFAULT" ,
280- cast ("DotPrecision" , "tf32" ),
281- mapping = {k : k for k in ("tf32" , "tf32x3" , "ieee" )},
282- )
283- )
294+ dot_precision : DotPrecision = dataclasses .field (default_factory = _get_dot_precision )
284295 static_shapes : bool = dataclasses .field (
285296 default_factory = functools .partial (_env_get_bool , "HELION_STATIC_SHAPES" , True )
286297 )
0 commit comments