Skip to content

Commit 887697b

Browse files
author
Ubuntu
committed
Trained with SQuAD dataset
1 parent dbe3679 commit 887697b

File tree

2 files changed

+44
-17
lines changed

2 files changed

+44
-17
lines changed

src/korea_travel_guide/inference.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,16 @@ def main():
5151
inf_args = parse_args()
5252

5353
# set device
54-
device = 0 if torch.cuda.is_available() else -1
54+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5555
logger.info(f"Using device: {device}")
5656

5757
# load tokenizer + model
5858
tok = AutoTokenizer.from_pretrained("facebook/bart-base")
5959
base_model = build_base_model()
60-
if training_args.use_sdpa_attention:
60+
if inf_args.use_sdpa_attention:
6161
base_model.config.attn_implementation = "sdpa"
6262
model = load_peft_model_for_inference(base_model)
63+
model.to(device)
6364

6465
# tokenize & format depending on mode
6566
if inf_args.mode == "test":
@@ -102,7 +103,8 @@ def main():
102103
return_tensors="pt",
103104
padding=True,
104105
truncation=True,
105-
).to(model.device)
106+
)
107+
enc = {k: v.to(device) for k, v in enc.items()}
106108

107109
# fast batched generate
108110
out = model.generate(

src/korea_travel_guide/train.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ class DataArgs:
3434
train_sample: bool = field(
3535
default=False, metadata={"help": "Use the mini CSV for smoke tests if True."}
3636
)
37-
37+
use_squad: bool = field(
38+
default=False, metadata={"help": "If True, ignore CSVs and load SQuAD instead."}
39+
)
3840

3941
# training & LoRA extras — extend HF’s own Seq2SeqTrainingArguments
4042
@dataclass
@@ -126,19 +128,42 @@ def main() -> None:
126128
logger.info(f"Set seed: {training_args.seed}")
127129

128130
# ---------- Data Preprocessing ----------
129-
# load and tokenize dataset
130-
# load CSVs
131-
data_files = {
132-
"train": str(
133-
data_args.train_sample_file
134-
if data_args.train_sample
135-
else data_args.train_file
136-
),
137-
"validation": str(data_args.validation_file),
138-
"test": str(data_args.test_file),
139-
}
140-
141-
ds = load_dataset("csv", data_files=data_files)
131+
# load either CSVs or SQuAD for a quick pipeline sanity check
132+
if data_args.use_squad:
133+
# 1) pull down SQuAD
134+
raw = load_dataset("squad")
135+
# 2) map to simple Q/A pairs (first answer only)
136+
def to_qa(ex):
137+
return {
138+
"question": ex["question"],
139+
"answer": ex["answers"]["text"][0]
140+
}
141+
ds = raw.map(to_qa, remove_columns=raw["train"].column_names)
142+
else:
143+
# load from your processed CSVs
144+
data_files = {
145+
"train": str(
146+
data_args.train_sample_file
147+
if data_args.train_sample
148+
else data_args.train_file
149+
),
150+
"validation": str(data_args.validation_file),
151+
"test": str(data_args.test_file),
152+
}
153+
ds = load_dataset("csv", data_files=data_files)
154+
# # load and tokenize dataset
155+
# # load CSVs
156+
# data_files = {
157+
# "train": str(
158+
# data_args.train_sample_file
159+
# if data_args.train_sample
160+
# else data_args.train_file
161+
# ),
162+
# "validation": str(data_args.validation_file),
163+
# "test": str(data_args.test_file),
164+
# }
165+
166+
# ds = load_dataset("csv", data_files=data_files)
142167
ds_tok, tok = tokenize_and_format(ds)
143168

144169
# initialize base model and LoRA

0 commit comments

Comments (0)