Skip to content

Commit 2b2af7c

Browse files
authored
Guard against random.seed() when making job names unique
Differential Revision: D88241480 Pull Request resolved: #1170
1 parent 9016924 commit 2b2af7c

File tree

3 files changed

+71
-37
lines changed

3 files changed

+71
-37
lines changed

torchx/schedulers/ids.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
# pyre-strict
99

1010
import os
11-
import random
1211
import struct
1312

13+
1414
START_CANDIDATES: str = "bcdfghjklmnpqrstvwxz"
1515
END_CANDIDATES: str = START_CANDIDATES + "012345679"
1616

@@ -19,14 +19,19 @@ def make_unique(name: str, string_length: int = 0) -> str:
1919
"""
2020
Appends a unique 64-bit string to the input argument.
2121
22+
Note that the unique string pulls entropy from `/dev/urandom` hence is not
23+
affected by `random.seed()`
24+
25+
Args:
26+
name: the name string to unique-ify
27+
string_length: max length of the unique 64-bit string to append to the ``name``.
28+
Default is 0, which returns the length of a randomly generated 64-bit string (typically 11-14 characters long).
29+
2230
Returns:
23-
string in format $name-$unique_suffix
31+
string in format ``{name}-{unique_suffix}`
2432
"""
25-
return (
26-
f"{name}-{random_id()}"
27-
if string_length == 0
28-
else f"{name}-{get_len_random_id(string_length)}"
29-
)
33+
max_length = None if string_length == 0 else string_length
34+
return f"{name}-{random_id(max_length)}"
3035

3136

3237
def random_uint64() -> int:
@@ -36,13 +41,24 @@ def random_uint64() -> int:
3641
return struct.unpack("!Q", os.urandom(8))[0]
3742

3843

39-
def random_id() -> str:
44+
def random_id(max_length: int | None = None) -> str:
4045
"""
4146
Generates an alphanumeric string ID that matches the requirements from
4247
https://kubernetes.io/docs/concepts/overview/working-with-objects/names/
48+
49+
Note that the unique string pulls entropy from `/dev/urandom` hence is not
50+
affected by `random.seed()`
51+
52+
If ``max_length`` is provided, the returned ID will be at most that many characters long.
53+
4354
"""
55+
# If a max_length is provided and is non-positive, return empty string
56+
if max_length is not None and max_length <= 0:
57+
return ""
58+
4459
out = ""
4560
v = random_uint64()
61+
4662
while v > 0:
4763
if out == "":
4864
candidates = START_CANDIDATES
@@ -52,21 +68,9 @@ def random_id() -> str:
5268
char = v % len(candidates)
5369
v = v // len(candidates)
5470
out += candidates[char]
55-
return out
56-
57-
58-
def get_len_random_id(string_length: int) -> str:
59-
"""
60-
Generates an alphanumeric string ID that matches the requirements from
61-
https://kubernetes.io/docs/concepts/overview/working-with-objects/names/
62-
"""
63-
out = ""
64-
for i in range(string_length):
65-
if out == "":
66-
candidates = START_CANDIDATES
67-
else:
68-
candidates = END_CANDIDATES
6971

70-
out += random.choice(candidates)
72+
if max_length is not None and len(out) >= max_length:
73+
break
7174

75+
# NOTE: statistically the length of `out` is typically between 12-14 characters long
7276
return out

torchx/schedulers/test/ids_test.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,27 @@
88
# pyre-strict
99

1010

11+
import typing
1112
import unittest
13+
from contextlib import contextmanager
1214
from unittest.mock import MagicMock, patch
1315

14-
from torchx.schedulers.ids import (
15-
get_len_random_id,
16-
make_unique,
17-
random_id,
18-
random_uint64,
19-
)
16+
from torchx.schedulers.ids import make_unique, random_id, random_uint64
17+
18+
19+
@contextmanager
20+
def scoped_random_seed(seed: int) -> typing.Generator[None, None, None]:
21+
"""
22+
Temporarily set the random module's seed and restore its state afterward.
23+
"""
24+
import random
25+
26+
state = random.getstate()
27+
try:
28+
random.seed(seed)
29+
yield
30+
finally:
31+
random.setstate(state)
2032

2133

2234
class IdsTest(unittest.TestCase):
@@ -42,15 +54,34 @@ def test_random_id(self) -> None:
4254
self.assertIn(v[0], ALPHAS)
4355
self.assertGreater(len(v), 5)
4456

45-
def test_get_len_random_id(self) -> None:
46-
size = 6
47-
self.assertNotEqual(get_len_random_id(size), get_len_random_id(size))
48-
self.assertEqual(size, len(get_len_random_id(size)))
57+
def test_random_id_max_length(self) -> None:
58+
for max_length in range(6, 10):
59+
with self.subTest(max_length=max_length):
60+
self.assertLessEqual(len(random_id(max_length)), max_length)
61+
self.assertNotEqual(random_id(max_length), random_id(max_length))
62+
63+
def test_random_id_zero_max_length(self) -> None:
64+
self.assertEqual("", random_id(max_length=0))
4965

5066
@patch("os.urandom", return_value=bytes(range(8)))
5167
def test_random_id_seed(self, urandom: MagicMock) -> None:
5268
self.assertEqual(random_id(), "fzfjxlmln9")
69+
self.assertEqual(random_id(max_length=6), "fzfjxl")
5370

5471
@patch("os.urandom", return_value=bytes(range(8)))
5572
def test_make_unique_seed(self, urandom: MagicMock) -> None:
5673
self.assertEqual(make_unique("test"), "test-fzfjxlmln9")
74+
75+
def test_make_unique_not_affected_by_random_seed(self) -> None:
76+
# Seeding the Python random module should not affect make_unique(),
77+
# which relies on os.urandom for entropy.
78+
with scoped_random_seed(0):
79+
v1 = make_unique("test")
80+
81+
with scoped_random_seed(0):
82+
v2 = make_unique("test")
83+
84+
# Even with the same random seed, make_unique should produce different values.
85+
self.assertNotEqual(v1, v2)
86+
self.assertTrue(v1.startswith("test-"))
87+
self.assertTrue(v2.startswith("test-"))

torchx/schedulers/test/kubernetes_mcad_scheduler_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -482,20 +482,19 @@ def test_cleanup_str(self) -> None:
482482
self.assertEqual("abcd1234", cleanup_str("1234abcd1234"))
483483

484484
def test_get_unique_truncated_appid(self) -> None:
485-
scheduler = create_scheduler("test")
486485
app = _test_app()
487486
app.name = "abcde"
488-
self.assertEqual(20, len(get_unique_truncated_appid(app)))
487+
self.assertLessEqual(len(get_unique_truncated_appid(app)), 20)
489488
self.assertIn(app.name, get_unique_truncated_appid(app))
490489

491490
app.name = "abcdefghijklmnopqrstuvwxyz012345678910111213141516"
492-
self.assertEqual(56, len(get_unique_truncated_appid(app)))
491+
self.assertLessEqual(len(get_unique_truncated_appid(app)), 56)
493492
self.assertIn(app.name, get_unique_truncated_appid(app))
494493

495494
app.name = (
496495
"abcdefghijklmnopqrstuvwxyz012345678910111213141516171819202122232425"
497496
)
498-
self.assertEqual(59, len(get_unique_truncated_appid(app)))
497+
self.assertLessEqual(len(get_unique_truncated_appid(app)), 59)
499498
self.assertIn(
500499
"abcdefghijklmnopqrstuvwxyz01234567891011121314151617181",
501500
get_unique_truncated_appid(app),

0 commit comments

Comments
 (0)