Commit d69caeb

increased dataset size to 3500 and optimized for training on T4 GPU
1 parent: f0d433a

File tree

3 files changed: +16 -6 lines changed

data/subreddit_size_map.json
src/korea_travel_guide/data.py
src/korea_travel_guide/train.py

data/subreddit_size_map.json

Lines changed: 6 additions & 2 deletions

@@ -1,4 +1,8 @@
 {
-    "askscience": 500,
-    "AskHistorians": 500
+    "ExplainLikeImFive": 500,
+    "AskPhysics": 500,
+    "AskSocialScience": 500,
+    "AskDocs": 500,
+    "askscience": 1500,
+    "AskHistorians": 1500
 }
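
The updated map caps the four new subreddits at 500 posts each and raises askscience and AskHistorians to 1,500 each, matching the 3,500-example target in the commit message. A minimal sanity-check sketch (assuming the file is read as plain JSON; not code that exists in the repo):

import json

# A sketch: verify the per-subreddit caps sum to the 3,500 examples
# mentioned in the commit message.
with open("data/subreddit_size_map.json") as f:
    size_map = json.load(f)

print(sum(size_map.values()))  # expected: 3500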

src/korea_travel_guide/data.py

Lines changed: 8 additions & 3 deletions

@@ -192,22 +192,27 @@ def split_and_save(df, out_dir: Union[str, Path]):
 def tokenize_and_format(
     ds: DatasetDict,
     checkpoint: str = "facebook/bart-base",
-    max_input_length: int = 1024,
-    max_target_length: int = 1024,
+    max_input_length: int = 224,  # max 1024
+    max_target_length: int = 800,  # max 1024
 ) -> Tuple[DatasetDict, AutoTokenizer]:
     tok = AutoTokenizer.from_pretrained(checkpoint)
 
     def _preprocess_batch(examples):
         # tokenize inputs
+        tok.truncation_side = "right"
         model_inputs = tok(
-            examples["question"], max_length=max_input_length, truncation=True
+            examples["question"],
+            max_length=max_input_length,
+            truncation=True,
         )
         # tokenize targets in “target” mode
+        tok.truncation_side = "left"
         labels = tok(
             text_target=examples["answer"],
             max_length=max_target_length,
             truncation=True,
         )
+        tok.truncation_side = "right"  # reset for safety
 
         model_inputs["labels"] = labels["input_ids"]
         return model_inputs
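
With the tighter max_input_length, the truncation_side switch controls which end of an over-long text survives: questions are truncated on the right (keeping their beginning) and answers on the left (keeping their ending). A standalone sketch of that behavior with the same facebook/bart-base tokenizer (illustrative only, not taken from the repo):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/bart-base")
text = " ".join(f"w{i}" for i in range(50))  # deliberately over-long toy input

# "right" truncation (used for questions): drop the tail, keep the start.
tok.truncation_side = "right"
kept_start = tok(text, max_length=10, truncation=True)

# "left" truncation (used for answers): drop the head, keep the end.
tok.truncation_side = "left"
kept_end = tok(text, max_length=10, truncation=True)

print(tok.decode(kept_start["input_ids"]))  # begins with w0 w1 w2 ...
print(tok.decode(kept_end["input_ids"]))    # ends with ... w48 w49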

src/korea_travel_guide/train.py

Lines changed: 2 additions & 1 deletion

@@ -51,12 +51,12 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
     lr_scheduler_type: str = "linear"
     warmup_ratio: float = 0.05
     num_train_epochs: int = 3
-
     per_device_train_batch_size: int = 8
     per_device_eval_batch_size: int = 16
     max_grad_norm: float = 0.5
     label_smoothing_factor: float = 0.1
     weight_decay: float = 0.01
+    generation_max_length: int = 384
     save_total_limit: int = 2
     fp16: bool = True
     predict_with_generate: bool = True

@@ -195,6 +195,7 @@ def main() -> None:
         model=lora_model,
         padding="longest",  # or "max_length"
         label_pad_token_id=-100,
+        pad_to_multiple_of=8,  # tensor-core friendly
     )
 
     # initialize trainer & train
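
Both additions target T4-friendly shapes: generation_max_length bounds eval-time generation, and pad_to_multiple_of=8 rounds each padded batch dimension up to a multiple of 8 so fp16 tensor cores are used efficiently. A rough sketch of the collator's effect (the repo passes its LoRA-wrapped model as model=lora_model; a plain BART model is used here only to keep the example self-contained):

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
)

tok = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

collator = DataCollatorForSeq2Seq(
    tok,
    model=model,               # the repo uses the LoRA-wrapped model here
    padding="longest",
    label_pad_token_id=-100,
    pad_to_multiple_of=8,      # tensor-core friendly on fp16 / T4
)

features = [
    {"input_ids": tok("short question")["input_ids"],
     "labels": tok("a short answer")["input_ids"]},
    {"input_ids": tok("a noticeably longer question that needs more tokens")["input_ids"],
     "labels": tok("a noticeably longer answer that needs more tokens")["input_ids"]},
]

batch = collator(features)
print(batch["input_ids"].shape)  # sequence length is rounded up to a multiple of 8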
