
Commit cb86ae2

Yijia Jin authored and jetstream authors committed
Refactor: Execute Maxtext inference as a module via python3 -m
Updated the recommended execution command for the inference script from `python MaxText/decode.py` to `python3 -m MaxText.decode`. This change uses Python's standard module execution mechanism (`-m`), which runs the script within its proper package context (`MaxText`) and makes import resolution more consistent and robust. No changes were made to the source code itself.

PiperOrigin-RevId: 748949600
1 parent c856af5 commit cb86ae2
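
To illustrate the mechanism (this example is not part of the commit; the package layout and names below are hypothetical), a package-relative import only resolves when the entry point is executed as a module:

```python
# A minimal three-file sketch showing why `-m` matters for import resolution:
#
#   pkg/__init__.py   (empty; marks `pkg` as a package)
#   pkg/helper.py     (contains just: GREETING = "hello")
#   pkg/main.py       (this file)

# Under `python3 pkg/main.py` the next line raises
#   ImportError: attempted relative import with no known parent package
# because the file runs as a standalone script with no package context.
# Under `python3 -m pkg.main`, __package__ == "pkg" and the import resolves.
from . import helper

if __name__ == "__main__":
    print(helper.GREETING)  # prints "hello" when launched via `python3 -m pkg.main`
```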

File tree (7 files changed: +23 −23 lines)

- benchmarks/mlperf/README.md
- benchmarks/mlperf/scripts/tpu_script.sh
- docs/observability-prometheus-metrics-in-jetstream-server.md
- docs/online-inference-with-maxtext-engine.md
- docs/profiling-with-jax-profiler-and-tensorboard.md
- jetstream/tools/maxtext/model_ckpt_conversion.sh
- jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh

benchmarks/mlperf/README.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -82,7 +82,7 @@ export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama2-70b-chat
 ```
 export TOKENIZER_PATH=maxtext/assets/tokenizer.llama2
 cd maxtext && \
-python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
+python3 -m MaxText.decode MaxText/configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
 ```
 
 Your checkpoint is generated at `$SAVE_QUANT_PARAMS_PATH`. This is used to set `load_parameters_path` param below in `MAXENGINE_ARGS` env variable.
@@ -96,7 +96,7 @@ huggingface-cli login
 Start Jetstream server in a terminal.
 ```
 cd ~/maxtext
-python MaxText/maxengine_server.py \
+python3 -m MaxText.maxengine_server \
 MaxText/configs/base.yml \
 tokenizer_path=assets/tokenizer.llama2 \
 load_parameters_path="gs://msingh-bkt/checkpoints/quant_llama2-70b-chat/mlperf_070924/int8_" \
````

benchmarks/mlperf/scripts/tpu_script.sh

Lines changed: 6 additions & 6 deletions
```diff
@@ -160,7 +160,7 @@ copy_relevant_files() {
 
 # # source .env/bin/activate
 # your_run_name=jwyang_bs1_llama7b
-# python MaxText/inference_microbenchmark.py \
+# python3 -m MaxText.inference_microbenchmark \
 # MaxText/configs/base.yml \
 # base_output_directory=gs://jwyang-data/maxtext-llama2-7b/microbenchmark \
 # run_name=${your_run_name} \
@@ -192,7 +192,7 @@ export load_parameters_path_chat=gs://jwyang-runner-maxtext-logs/llama2-7b_unsca
 export load_parameters_path=gs://jwyang-runner-maxtext-logs/llama2-7b_unscanned_chkpt_2024-04-26-19-40/checkpoints/0/items
 export load_parameters_path_chat_quantized=gs://jwyang-data/llama7b-chat-quantized-fixed/0/items
 
-python MaxText/maxengine_server.py \
+python3 -m MaxText.maxengine_server \
 MaxText/configs/base.yml \
 base_output_directory=gs://jwyang-data/maxtext-llama2-7b/microbenchmark \
 load_parameters_path=${load_parameters_path_chat} \
@@ -244,7 +244,7 @@ export load_parameters_path=gs://runner-maxtext-logs/2024-05-16-23-59/unscanned_
 
 export experiment_time=$(date +%Y-%m-%d-%H-%M)
 echo "export experiment_time=${experiment_time}"
-python MaxText/maxengine_server.py \
+python3 -m MaxText.maxengine_server \
 MaxText/configs/base.yml \
 base_output_directory=gs://morgandu-tpu/maxtext-logs/microbenchmark/${experiment_time} \
 model_name=llama2-13b \
@@ -269,7 +269,7 @@ python MaxText/maxengine_server.py \
 per_device_batch_size=1
 
 
-python MaxText/inference_microbenchmark.py \
+python3 -m MaxText.inference_microbenchmark \
 MaxText/configs/base.yml \
 base_output_directory=gs://morgandu-tpu/maxtext-logs/microbenchmark/${experiment_time} \
 model_name=llama2-13b \
@@ -298,7 +298,7 @@ python MaxText/inference_microbenchmark.py \
 # # LLaMA2-70B commands
 # # source .env/bin/activate
 # your_run_name=jwyang_bs1_llama70b
-# python MaxText/inference_microbenchmark.py \
+# python3 -m MaxText.inference_microbenchmark \
 # MaxText/configs/base.yml \
 # base_output_directory=gs://jwyang-data/maxtext-llama2-70b/microbenchmark \
 # run_name=${your_run_name} \
@@ -328,7 +328,7 @@ export per_device_batch_size=1
 export prefill_length=16
 export target_length=32
 
-python MaxText/maxengine_server.py \
+python3 -m MaxText.maxengine_server \
 MaxText/configs/base.yml \
 base_output_directory=gs://jwyang-data/maxtext-llama2-70b/microbenchmark \
 run_name=$(date +%Y-%m-%d-%H-%M) \
```

docs/observability-prometheus-metrics-in-jetstream-server.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -23,7 +23,7 @@ export PER_DEVICE_BATCH_SIZE=11
 export PROMETHEUS_PORT=9090
 
 cd ~/maxtext
-python MaxText/maxengine_server.py \
+python3 -m MaxText.maxengine_server \
 MaxText/configs/base.yml \
 tokenizer_path=${TOKENIZER_PATH} \
 load_parameters_path=${LOAD_PARAMETERS_PATH} \
```

docs/online-inference-with-maxtext-engine.md

Lines changed: 6 additions & 6 deletions
````diff
@@ -157,7 +157,7 @@ export PER_DEVICE_BATCH_SIZE=4
 
 ```bash
 cd ~/maxtext
-python MaxText/maxengine_server.py \
+python3 -m MaxText.maxengine_server \
 MaxText/configs/base.yml \
 tokenizer_path=${TOKENIZER_PATH} \
 load_parameters_path=${LOAD_PARAMETERS_PATH} \
@@ -225,12 +225,12 @@ There are several different quantization configurations to choose from:
 
 #### int8 DRQ quantized checkpoint
 ```bash
-python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
+python3 -m MaxText.decode MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
 ```
 
 #### Weights-only int8 quantized checkpoint
 ```bash
-python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8w save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
+python3 -m MaxText.decode MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8w save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
 ```
 
 #### Mixed precision weight-only quantized checkpoint
@@ -247,7 +247,7 @@ First, update the mixed precision config file (`MaxText/configs/quantization/mp_
 ```
 Then run the following command:
 ```bash
-python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=intmp
+python3 -m MaxText.decode MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=intmp
 quant_cfg_path=configs/quantization/mp_scale.json save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
 ```
 
@@ -291,7 +291,7 @@ export QUANTIZE_KVCACHE=False
 export PER_DEVICE_BATCH_SIZE=12
 
 cd ~/maxtext
-python MaxText/maxengine_server.py \
+python3 -m MaxText.maxengine_server \
 MaxText/configs/base.yml \
 tokenizer_path=${TOKENIZER_PATH} \
 load_parameters_path=${LOAD_PARAMETERS_PATH} \
@@ -311,7 +311,7 @@ python MaxText/maxengine_server.py \
 
 For the mixed precision quantized model
 ```bash
-python MaxText/maxengine_server.py \
+python3 -m MaxText.maxengine_server \
 MaxText/configs/base.yml \
 tokenizer_path=${TOKENIZER_PATH} \
 load_parameters_path=${LOAD_PARAMETERS_PATH} \
````

docs/profiling-with-jax-profiler-and-tensorboard.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -39,7 +39,7 @@ export ENABLE_JAX_PROFILER=true
 export JAX_PROFILER_PORT=9999
 
 cd ~/maxtext
-python MaxText/maxengine_server.py \
+python3 -m MaxText.maxengine_server \
 MaxText/configs/base.yml \
 tokenizer_path=${TOKENIZER_PATH} \
 load_parameters_path=${LOAD_PARAMETERS_PATH} \
```

jetstream/tools/maxtext/model_ckpt_conversion.sh

Lines changed: 6 additions & 6 deletions
```diff
@@ -55,8 +55,8 @@ gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCAT
 
 # Convert model checkpoints to MaxText compatible checkpoints.
 if [ "$MODEL" == "gemma" ]; then
-CONVERT_CKPT_SCRIPT="convert_gemma_chkpt.py"
-JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
+CONVERT_CKPT_SCRIPT="convert_gemma_chkpt"
+JAX_PLATFORMS=cpu python3 -m MaxText.${CONVERT_CKPT_SCRIPT} \
 --base_model_path ${CHKPT_BUCKET} \
 --maxtext_model_path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
 --model_size ${MODEL_VARIATION}
@@ -87,14 +87,14 @@ else
 lora_local_path=${LORA_INPUT_ADAPTERS_PATH}
 fi
 
-JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
+JAX_PLATFORMS=cpu python3 -m MaxText.${CONVERT_CKPT_SCRIPT} \
 --base-model-path ${tmp_ckpt_path}${directory_substring} \
 --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
 --model-size ${MODEL_NAME} \
 --lora-input-adapters-path ${lora_local_path} \
 --huggingface-checkpoint ${HUGGING_FACE_CHECKPOINT}
 else
-JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
+JAX_PLATFORMS=cpu python3 -m MaxText.${CONVERT_CKPT_SCRIPT} \
 --base-model-path ${tmp_ckpt_path}${directory_substring} \
 --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
 --model-size ${MODEL_NAME} \
@@ -111,7 +111,7 @@ export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}
 export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx}
 
 if [[ ! -z "${LORA_INPUT_ADAPTERS_PATH}" ]]; then
-JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
+JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint \
 MaxText/configs/base.yml \
 base_output_directory=${BASE_OUTPUT_DIRECTORY} \
 load_parameters_path=${SCANNED_CKPT_PATH}/base/0/items \
@@ -121,7 +121,7 @@ if [[ ! -z "${LORA_INPUT_ADAPTERS_PATH}" ]]; then
 force_unroll=true
 echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints"
 else
-JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
+JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint \
 MaxText/configs/base.yml \
 base_output_directory=${BASE_OUTPUT_DIRECTORY} \
 load_parameters_path=${SCANNED_CKPT_PATH}/0/items \
```
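
Aside: the first hunk in this file also drops the `.py` suffix when setting `CONVERT_CKPT_SCRIPT`, because `python3 -m` takes a dotted module name rather than a file path. The same distinction can be sketched with Python's standard `runpy` module (the MaxText module name below is used purely for illustration, not taken from the commit):

```python
import runpy

# File-path execution, analogous to `python MaxText/convert_gemma_chkpt.py`:
# runs the file as a standalone script, with no package context.
runpy.run_path("MaxText/convert_gemma_chkpt.py", run_name="__main__")

# Module execution, analogous to `python3 -m MaxText.convert_gemma_chkpt`:
# imports the `MaxText` package first, so package-relative imports resolve.
runpy.run_module("MaxText.convert_gemma_chkpt", run_name="__main__")
```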

jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -70,7 +70,7 @@ export AQT_CKPT=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/100/items
 # Note that the `AQT_CKPT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
 export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx}
 
-JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
+JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint \
 MaxText/configs/base.yml \
 base_output_directory=${BASE_OUTPUT_DIRECTORY} \
 load_parameters_path=${AQT_CKPT} \
```
