Skip to content

Commit a6e1c2d

Browse files
authored
Extend LLM checkpoint suffixes to support .safetensors (#5021)
* Extend LLM checkpoint suffixes to support .safetensors * Remove the '*.safetensors' ignore pattern
1 parent 159c097 commit a6e1c2d

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

examples/gpu/llm/inference/run_generation_with_deepspeed.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from argparse import ArgumentParser
1010
from pathlib import Path
1111
import torch
12+
import re
1213

1314
import deepspeed
1415
from deepspeed.accelerator import get_accelerator
@@ -156,7 +157,7 @@ def get_repo_root(model_name_or_path):
156157
model_name_or_path,
157158
local_files_only=is_offline_mode(),
158159
cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
159-
ignore_patterns=["*.safetensors", "*.msgpack", "*.h5"],
160+
ignore_patterns=["*.msgpack", "*.h5"],
160161
resume_download=True,
161162
)
162163

@@ -166,17 +167,23 @@ def get_repo_root(model_name_or_path):
166167
model_name_or_path,
167168
local_files_only=is_offline_mode(),
168169
cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
169-
ignore_patterns=["*.safetensors", "*.msgpack", "*.h5"],
170+
ignore_patterns=["*.msgpack", "*.h5"],
170171
resume_download=True,
171172
)
172173

173174

174175
def get_checkpoint_files(model_name_or_path):
175176
cached_repo_dir = get_repo_root(model_name_or_path)
176177

177-
# extensions: .bin | .pt
178+
# extensions: .bin | .pt | .safetensors
178179
# creates a list of paths from all downloaded files in cache dir
179-
file_list = [str(entry) for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") if entry.is_file()]
180+
file_list = list()
181+
pattern_sample = re.compile(r'(.*).(safetensors|bin|pt)$')
182+
for entry in Path(cached_repo_dir).rglob("*"):
183+
match = re.match(pattern=pattern_sample, string=str(entry))
184+
if match:
185+
file_list.append(str(entry))
186+
180187
return file_list
181188

182189

0 commit comments

Comments
 (0)