Skip to content

Commit 81dc268

Browse files
author
Ubuntu
committed
updated sweep.yaml with correct syntax
1 parent 785fd7b commit 81dc268

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

src/korea_travel_guide/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def split_and_save(df, out_dir: Union[str, Path]):
192192
def tokenize_and_format(
193193
ds: DatasetDict,
194194
checkpoint: str = "facebook/bart-base",
195-
max_input_length: int = 224, # max 1024
195+
max_input_length: int = 1024, # max 1024 (previously 224)
196196
max_target_length: int = 800, # max 1024
197197
) -> Tuple[DatasetDict, AutoTokenizer]:
198198
tok = AutoTokenizer.from_pretrained(checkpoint)

src/korea_travel_guide/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def main() -> None:
173173
eval_dataset=ds_tok["validation"],
174174
tokenizer=tok,
175175
data_collator=data_collator,
176-
compute_metrics=build_compute_metrics(tok),
176+
# compute_metrics=build_compute_metrics(tok),
177177
)
178178

179179
trainer.train()

sweep.yaml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
11
# sweep spec for Korea-Travel-Guide BART
22
program: scripts/train.py
3+
4+
project: bart-base-korea-travel-guide-lora
5+
entity: codinglabsong-keio-jp
6+
37
method: bayes # {grid | random | bayes}
4-
run_cap: 5 # sweep run limit
8+
run_cap: 10 # sweep run limit
59

610
metric: # what to optimise
7-
name: eval/rougeL # must match the key in evaluation.compute_metrics returns
11+
name: eval/loss # must match a key that evaluation.compute_metrics returns
812
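goal: maximize # NOTE(review): the metric is now eval/loss; a loss is normally minimized, so confirm this should not be "minimize"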
913

1014
parameters:
1115
learning_rate:
12-
min: 1e-5
13-
max: 1e-3
14-
distribution: log_uniform
16+
min: 0.00001
17+
max: 0.001
18+
distribution: log_uniform_values
1519
num_train_epochs:
16-
values: 1
20+
values: [1]
1721
peft_rank:
1822
values: [4, 8, 16]
1923
train_sample:
20-
values: True
24+
values: [True]
2125

2226
early_terminate:
2327
type: hyperband

0 commit comments

Comments
 (0)