
Commit 76f3790

Committed by: Ubuntu

removed flash attention in favor of sdpa, which is supported on the T4 GPU

1 parent d69caeb · commit 76f3790
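
Background for the switch: in recent PyTorch builds the flash backend of scaled_dot_product_attention (FlashAttention-2) requires an Ampere-class GPU (compute capability 8.0+), while the T4 is Turing (7.5); the memory-efficient backend, by contrast, does run on the T4. Selecting the SDPA attention implementation therefore lets the T4 use the best kernel it actually supports. A minimal sketch of the new approach, assuming a transformers version recent enough that BART ships an SDPA attention implementation (the model name follows the repo; the sample input is a placeholder):

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("facebook/bart-base")

    # Select PyTorch's scaled_dot_product_attention at load time; on a T4,
    # SDPA dispatches to the memory-efficient kernel instead of flash.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/bart-base", attn_implementation="sdpa"
    )

    enc = tok(["Recommend a 3-day Seoul itinerary."], return_tensors="pt")
    out = model.generate(**enc, max_length=64, num_beams=5)
    print(tok.batch_decode(out, skip_special_tokens=True))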

2 files changed: +25 −70 lines changed


src/korea_travel_guide/inference.py

Lines changed: 14 additions & 24 deletions
@@ -34,8 +34,8 @@ class InferenceArgs:
         default_factory=list,
         metadata={"help": "One or more input texts for `predict` mode."},
     )
-    use_flash_attention: bool = field(
-        default=True, metadata={"help": "Enable Flash Attention v1 (via sdp_kernel)."}
+    use_sdpa_attention: bool = field(
+        default=True, metadata={"help": "Enable SDPA (memory-efficient attention kernel)."}
     )


@@ -57,18 +57,10 @@ def main():
     # load tokenizer + model
     tok = AutoTokenizer.from_pretrained("facebook/bart-base")
     base_model = build_base_model()
+    if inf_args.use_sdpa_attention:
+        base_model.config.attn_implementation = "sdpa"
     model = load_peft_model_for_inference(base_model)

-    # prepare Flash-Attention context
-    if inf_args.use_flash_attention and device >= 0:
-        logger.info("Using flash attention")
-        ctx = torch.backends.cuda.sdp_kernel(
-            enable_flash=True, enable_math=True, enable_mem_efficient=True
-        )
-    else:
-        logger.info("Skipping flash attention v1")
-        ctx = nullcontext()
-
     # tokenize & format depending on mode
     if inf_args.mode == "test":
         # load dataset
@@ -97,8 +89,7 @@ def main():
         compute_metrics=build_compute_metrics(tok),
     )

-    with ctx:
-        pred_output = trainer.predict(ds_tok["test"])
+    pred_output = trainer.predict(ds_tok["test"])
     metrics = pred_output.metrics
     logger.info(f"Test metrics: {metrics}")

@@ -113,16 +104,15 @@ def main():
         truncation=True,
     ).to(model.device)

-    with ctx:
-        # fast batched generate
-        out = model.generate(
-            **enc,
-            max_length=512,
-            num_beams=5,
-            early_stopping=True,
-            length_penalty=1.0,
-            repetition_penalty=1.1,
-        )
+    # fast batched generate
+    out = model.generate(
+        **enc,
+        max_length=512,
+        num_beams=5,
+        early_stopping=True,
+        length_penalty=1.0,
+        repetition_penalty=1.1,
+    )

     decoded = tok.batch_decode(out, skip_special_tokens=True)
     for inp, pred in zip(inf_args.texts, decoded):
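
A caveat on the new toggle (an observation about the diff, not something it tests): transformers picks the attention classes when the model is instantiated, driven by the attn_implementation argument to from_pretrained, so mutating config.attn_implementation on an already-built model may not actually switch kernels. A sketch of threading the flag through model construction instead; build_base_model's real signature is assumed, and the helper shown here is hypothetical:

    from transformers import AutoModelForSeq2SeqLM

    def build_base_model(attn_implementation: str = "sdpa"):
        # Hypothetical variant of the repo's build_base_model: choose the
        # attention backend at load time instead of editing the config later.
        return AutoModelForSeq2SeqLM.from_pretrained(
            "facebook/bart-base", attn_implementation=attn_implementation
        )

    use_sdpa = True  # stands in for inf_args.use_sdpa_attention
    base_model = build_base_model("sdpa" if use_sdpa else "eager")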

src/korea_travel_guide/train.py

Lines changed: 11 additions & 46 deletions
@@ -46,16 +46,16 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
     )
     eval_strategy: str = "epoch"
     save_strategy: str = "epoch"
-    logging_steps: int = 5
+    logging_steps: int = 50
     learning_rate: float = 1e-4
     lr_scheduler_type: str = "linear"
     warmup_ratio: float = 0.05
-    num_train_epochs: int = 3
-    per_device_train_batch_size: int = 8
-    per_device_eval_batch_size: int = 16
+    num_train_epochs: int = 8
+    per_device_train_batch_size: int = 16
+    per_device_eval_batch_size: int = 32
     max_grad_norm: float = 0.5
     label_smoothing_factor: float = 0.1
-    weight_decay: float = 0.01
+    # weight_decay: float = 0.01
     generation_max_length: int = 384
     save_total_limit: int = 2
     fp16: bool = True
@@ -76,8 +76,8 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
         default=False,
         metadata={"help": "If True, run the test-split evaluation after training."},
     )
-    use_flash_attention: bool = field(
-        default=True, metadata={"help": "Whether to enable Flash Attention v1."}
+    use_sdpa_attention: bool = field(
+        default=True, metadata={"help": "Enable SDPA (memory-efficient attention kernel)."}
     )


@@ -143,6 +143,8 @@ def main() -> None:

     # initialize base model and LoRA
     base_model = build_base_model()
+    if training_args.use_sdpa_attention:
+        base_model.config.attn_implementation = "sdpa"
     logger.info(
         f"Base model trainable params:\n{print_trainable_parameters(base_model)}"
     )
@@ -153,42 +155,7 @@ def main() -> None:
         f"LoRA model (peft_rank={training_args.peft_rank}, lora_alpha={training_args.lora_alpha}) trainable params:\n{print_trainable_parameters(lora_model)}"
     )

-    # from torch.utils.data import DataLoader
-
-    # data_collator = DataCollatorForSeq2Seq(
-    #     tok,
-    #     model=lora_model,
-    #     padding="longest",
-    #     label_pad_token_id=-100,
-    # )
-
-    # batch = next(iter(DataLoader(ds_tok["train"], batch_size=2, collate_fn=data_collator)))
-    # # 1) decode inputs normally
-    # print("INPUTS:")
-    # print(tok.batch_decode(batch["input_ids"], skip_special_tokens=True))
-
-    # # 2) map -100 → pad_token_id before decoding labels
-    # labels = batch["labels"].detach().cpu().numpy()
-    # labels = np.where(labels != -100, labels, tok.pad_token_id)
-
-    # print("LABELS:")
-    # print(tok.batch_decode(labels, skip_special_tokens=True))
-
-    # import sys
-    # sys.exit()
-
     # ---------- Train ----------
-    # toggle flash attention
-    if training_args.use_flash_attention:
-        logger.info("Using flash attention")
-        ctx = torch.backends.cuda.sdp_kernel(
-            enable_flash=True, enable_math=True, enable_mem_efficient=True
-        )
-    else:
-        logger.info("Skipping flash attention")
-        ctx = nullcontext()
-    # ctx = nullcontext()
-
     # data collator: dynamic padding per batch
     data_collator = DataCollatorForSeq2Seq(
         tok,
@@ -209,15 +176,13 @@ def main() -> None:
         compute_metrics=build_compute_metrics(tok),
     )

-    with ctx:
-        trainer.train()
+    trainer.train()

     # ---------- Save, Test or Push ----------
     # evaluate test
     if training_args.run_test:
         logger.info("Running final test-set evaluation...")
-        with ctx:
-            metrics = trainer.evaluate(ds_tok["test"])
+        metrics = trainer.evaluate(ds_tok["test"])
         logger.info(f"Test metrics:\n{metrics}")
     else:
         logger.info("Skipping test evaluation.")
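
To confirm on the target hardware which SDPA backends are usable, the removed context manager can still serve as a probe, since scaled_dot_product_attention raises when every allowed backend is unavailable. A sketch assuming a CUDA build of PyTorch in which torch.backends.cuda.sdp_kernel still exists (newer releases deprecate it in favor of torch.nn.attention.sdpa_kernel):

    import torch
    import torch.nn.functional as F

    # (batch, heads, seq_len, head_dim) in fp16, matching the fp16 training setup.
    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

    # Allow only the memory-efficient backend; on a T4 this succeeds, whereas
    # enable_flash=True alone would fail since flash needs compute capability 8.0+.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_math=False, enable_mem_efficient=True
    ):
        out = F.scaled_dot_product_attention(q, k, v)

    print("memory-efficient SDPA ran, output shape:", tuple(out.shape))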
