Skip to content

Commit 2e53e64

Browse files
authored
settings: set appropriate dot_precision default (#1184)
1 parent e75c434 commit 2e53e64

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

helion/runtime/settings.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torch._environment import is_fbcode
1919

2020
from .. import exc
21+
from .._compat import supports_amd_cdna_tunables
2122
from ..autotuner.effort_profile import AutotuneEffort
2223
from ..autotuner.effort_profile import get_effort_profile
2324
from .ref_mode import RefMode
@@ -264,6 +265,23 @@ def _get_ref_mode() -> RefMode:
264265
return RefMode.EAGER if interpret else RefMode.OFF
265266

266267

268+
def _get_dot_precision() -> DotPrecision:
269+
"""
270+
Get the dot precision setting from TRITON_F32_DEFAULT environment variable.
271+
Defaults to 'tf32', 'ieee' if rocm and not CDNA.
272+
"""
273+
if torch.version.hip is not None:
274+
default_precision = "tf32" if supports_amd_cdna_tunables() else "ieee"
275+
else:
276+
default_precision = "tf32"
277+
278+
return _env_get_literal(
279+
"TRITON_F32_DEFAULT",
280+
cast("DotPrecision", default_precision),
281+
mapping={k: k for k in ("tf32", "tf32x3", "ieee")},
282+
)
283+
284+
267285
@dataclasses.dataclass
268286
class _Settings:
269287
# see __slots__ below for the doc strings that show up in help(Settings)
@@ -273,14 +291,7 @@ class _Settings:
273291
index_dtype: torch.dtype | None = dataclasses.field(
274292
default_factory=_get_index_dtype
275293
)
276-
dot_precision: DotPrecision = dataclasses.field(
277-
default_factory=functools.partial(
278-
_env_get_literal,
279-
"TRITON_F32_DEFAULT",
280-
cast("DotPrecision", "tf32"),
281-
mapping={k: k for k in ("tf32", "tf32x3", "ieee")},
282-
)
283-
)
294+
dot_precision: DotPrecision = dataclasses.field(default_factory=_get_dot_precision)
284295
static_shapes: bool = dataclasses.field(
285296
default_factory=functools.partial(_env_get_bool, "HELION_STATIC_SHAPES", True)
286297
)

0 commit comments

Comments
 (0)