Skip to content

Commit ce53cc0

Browse files
[V5] Return a BatchEncoding dict from apply_chat_template by default again (#42567)
* Flip the default return type for `apply_chat_template` to match the underlying tokenizer
* Remove test_tokenization_for_chat tests, which no longer do anything useful
* Fix test_encode_message tests
* nit fix
* Trigger tests
* make fixup
* Add a little test to make sure that doesn't happen again
1 parent a3e2d54 commit ce53cc0

File tree

7 files changed

+18
-136
lines changed

7 files changed

+18
-136
lines changed

src/transformers/tokenization_utils_base.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3195,7 +3195,7 @@ def apply_chat_template(
31953195
truncation: bool = False,
31963196
max_length: Optional[int] = None,
31973197
return_tensors: Optional[Union[str, TensorType]] = None,
3198-
return_dict: bool = False,
3198+
return_dict: bool = True,
31993199
return_assistant_tokens_mask: bool = False,
32003200
tokenizer_kwargs: Optional[dict[str, Any]] = None,
32013201
**kwargs,
@@ -3268,14 +3268,11 @@ def apply_chat_template(
32683268
set, will return a dict of tokenizer outputs instead.
32693269
"""
32703270

3271-
if return_dict and not tokenize:
3272-
raise ValueError(
3273-
"`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
3274-
"of tokenizer outputs to return."
3275-
)
3271+
if not tokenize:
3272+
return_dict = False # dicts are only returned by the tokenizer anyway
32763273

3277-
if return_assistant_tokens_mask and not return_dict:
3278-
raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`")
3274+
if return_assistant_tokens_mask and not (return_dict and tokenize):
3275+
raise ValueError("`return_assistant_tokens_mask=True` requires `return_dict=True` and `tokenize=True`")
32793276

32803277
if tokenizer_kwargs is None:
32813278
tokenizer_kwargs = {}
@@ -3390,13 +3387,17 @@ def encode_message_with_chat_template(
33903387
)
33913388

33923389
if conversation_history is None or len(conversation_history) == 0:
3393-
return self.apply_chat_template([message], add_generation_prompt=False, tokenize=True, **kwargs)
3390+
return self.apply_chat_template(
3391+
[message], add_generation_prompt=False, tokenize=True, return_dict=False, **kwargs
3392+
)
33943393

33953394
conversation = conversation_history + [message]
3396-
tokens = self.apply_chat_template(conversation, add_generation_prompt=False, tokenize=True, **kwargs)
3395+
tokens = self.apply_chat_template(
3396+
conversation, add_generation_prompt=False, tokenize=True, return_dict=False, **kwargs
3397+
)
33973398

33983399
prefix_tokens = self.apply_chat_template(
3399-
conversation_history, add_generation_prompt=False, tokenize=True, **kwargs
3400+
conversation_history, add_generation_prompt=False, tokenize=True, return_dict=False, **kwargs
34003401
)
34013402
# It's possible that the prefix tokens are not a prefix of the full list of tokens.
34023403
# For example, if the prefix is `<s>User: Hi` and the full conversation is `<s>User: Hi</s><s>Assistant: Hello`.

tests/models/blenderbot/test_tokenization_blenderbot.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,3 @@ def test_pretokenized_inputs(self, *args, **kwargs):
2121
# The issue is that when you have a sequence with leading spaces, splitting it
2222
# with .split() loses the leading spaces, so the tokenization results differ
2323
pass
24-
25-
def test_tokenization_for_chat(self):
26-
tok = self.get_tokenizer()
27-
test_chats = [
28-
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
29-
[
30-
{"role": "system", "content": "You are a helpful chatbot."},
31-
{"role": "user", "content": "Hello!"},
32-
{"role": "assistant", "content": "Nice to meet you."},
33-
],
34-
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
35-
]
36-
tokenized_chats = [tok.apply_chat_template(test_chat) for test_chat in test_chats]
37-
expected_tokens = [
38-
[553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 2],
39-
[553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 228, 3490, 287, 2273, 304, 21, 2],
40-
[3490, 287, 2273, 304, 21, 228, 228, 6950, 8, 2],
41-
]
42-
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
43-
self.assertListEqual(tokenized_chat, expected_tokens)

tests/models/bloom/test_tokenization_bloom.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from datasets import load_dataset
1818

1919
from transformers import TokenizersBackend
20-
from transformers.testing_utils import require_jinja, require_tokenizers, slow
20+
from transformers.testing_utils import require_tokenizers, slow
2121

2222
from ...test_tokenization_common import TokenizerTesterMixin
2323

@@ -129,28 +129,6 @@ def test_encodings_from_xnli_dataset(self):
129129
predicted_text = [tokenizer.decode(x, clean_up_tokenization_spaces=False) for x in output_tokens]
130130
self.assertListEqual(predicted_text, input_text)
131131

132-
@require_jinja
133-
def test_tokenization_for_chat(self):
134-
tokenizer = self.get_tokenizer()
135-
tokenizer.chat_template = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}"
136-
test_chats = [
137-
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
138-
[
139-
{"role": "system", "content": "You are a helpful chatbot."},
140-
{"role": "user", "content": "Hello!"},
141-
{"role": "assistant", "content": "Nice to meet you."},
142-
],
143-
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
144-
]
145-
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
146-
expected_tokens = [
147-
[5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2],
148-
[5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2, 229126, 427, 11890, 1152, 17, 2],
149-
[229126, 427, 11890, 1152, 17, 2, 59414, 4, 2],
150-
]
151-
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
152-
self.assertListEqual(tokenized_chat, expected_tokens)
153-
154132
def test_add_prefix_space_fast(self):
155133
tokenizer_w_prefix = self.get_tokenizer(add_prefix_space=True)
156134
tokenizer_wo_prefix = self.get_tokenizer(add_prefix_space=False)

tests/models/cohere/test_tokenization_cohere.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -73,32 +73,6 @@ def test_pretrained_model_lists(self):
7373
self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1)
7474
self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1)
7575

76-
@require_jinja
77-
def test_tokenization_for_chat(self):
78-
tokenizer = self.get_tokenizer()
79-
test_chats = [
80-
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
81-
[
82-
{"role": "system", "content": "You are a helpful chatbot."},
83-
{"role": "user", "content": "Hello!"},
84-
{"role": "assistant", "content": "Nice to meet you."},
85-
],
86-
]
87-
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
88-
# fmt: off
89-
expected_tokens = [
90-
[5, 36, 99, 59, 60, 41, 58, 60, 71, 55, 46, 71, 60, 61, 58, 54, 71, 60, 55, 51, 45, 54, 99, 38, 36, 99, 59, 65, 59, 60, 45, 53, 71, 60, 55, 51, 45, 54, 99, 38, 65, 243, 394, 204, 336, 84, 88, 887, 374, 216, 74, 286, 22, 8, 36, 99, 59, 60, 41, 58, 60, 71, 55, 46, 71, 60, 61, 58, 54, 71, 60, 55, 51, 45, 54, 99, 38, 36, 99, 61, 59, 45, 58, 71, 60, 55, 51, 45, 54, 99, 38, 48, 420, 87, 9, 8],
91-
[5, 36, 99, 59, 60, 41, 58, 60, 71, 55, 46, 71, 60, 61, 58, 54, 71, 60, 55, 51, 45, 54, 99, 38, 36, 99, 59, 65,
92-
59, 60, 45, 53, 71, 60, 55, 51, 45, 54, 99, 38, 65, 243, 394, 204, 336, 84, 88, 887, 374, 216, 74, 286, 22, 8,
93-
36, 99, 59, 60, 41, 58, 60, 71, 55, 46, 71, 60, 61, 58, 54, 71, 60, 55, 51, 45, 54, 99, 38, 36, 99, 61, 59,
94-
45, 58, 71, 60, 55, 51, 45, 54, 99, 38, 48, 420, 87, 9, 8, 36, 99, 59, 60, 41, 58, 60, 71, 55, 46, 71, 60, 61,
95-
58, 54, 71, 60, 55, 51, 45, 54, 99, 38, 36, 99, 43, 48, 41, 60, 42, 55, 60, 71, 60, 55, 51, 45, 54, 99, 38,
96-
54, 567, 235, 693, 276, 411, 243, 22, 8]
97-
]
98-
# fmt: on
99-
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
100-
self.assertListEqual(tokenized_chat, expected_tokens)
101-
10276
@require_jinja
10377
def test_tokenization_for_tool_use(self):
10478
tokenizer = self.get_tokenizer()

tests/models/gpt2/test_tokenization_gpt2.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import unittest
1717

1818
from transformers import AutoTokenizer, GPT2Tokenizer
19-
from transformers.testing_utils import require_jinja, require_tiktoken, require_tokenizers
19+
from transformers.testing_utils import require_tiktoken, require_tokenizers
2020

2121
from ...test_tokenization_common import TokenizerTesterMixin
2222

@@ -67,26 +67,6 @@ def test_special_tokens_mask_input_pairs_and_bos_token(self):
6767
filtered_sequence = [x for x in filtered_sequence if x is not None]
6868
self.assertEqual(encoded_sequence, filtered_sequence)
6969

70-
@require_jinja
71-
def test_tokenization_for_chat(self):
72-
tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname)
73-
tokenizer.chat_template = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}"
74-
test_chats = [
75-
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
76-
[
77-
{"role": "system", "content": "You are a helpful chatbot."},
78-
{"role": "user", "content": "Hello!"},
79-
{"role": "assistant", "content": "Nice to meet you."},
80-
],
81-
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
82-
]
83-
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
84-
# fmt: off
85-
expected_tokens = [[1639, 389, 257, 7613, 8537, 13645, 13, 50256, 15496, 0, 50256], [1639, 389, 257, 7613, 8537, 13645, 13, 50256, 15496, 0, 50256, 35284, 284, 1826, 345, 13, 50256], [35284, 284, 1826, 345, 13, 50256, 15496, 0, 50256]]
86-
# fmt: on
87-
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
88-
self.assertListEqual(tokenized_chat, expected_tokens)
89-
9070
@require_tiktoken
9171
def test_tokenization_tiktoken(self):
9272
from tiktoken import encoding_name_for_model

tests/models/gpt_sw3/test_tokenization_gpt_sw3.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import unittest
1616

1717
from transformers import GPTSw3Tokenizer
18-
from transformers.testing_utils import get_tests_dir, require_jinja, require_sentencepiece, require_tokenizers, slow
18+
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
1919

2020
from ...test_tokenization_common import TokenizerTesterMixin
2121

@@ -129,36 +129,3 @@ def test_tokenizer_integration(self):
129129
model_name="AI-Sweden-Models/gpt-sw3-126m",
130130
sequences=sequences,
131131
)
132-
133-
@require_jinja
134-
def test_tokenization_for_chat(self):
135-
tokenizer = GPTSw3Tokenizer(SAMPLE_VOCAB, name_or_path="test")
136-
tokenizer.chat_template = (
137-
"{{ eos_token }}{{ bos_token }}"
138-
"{% for message in messages %}"
139-
"{% if message['role'] == 'user' %}{{ 'User: ' + message['content']}}"
140-
"{% else %}{{ 'Bot: ' + message['content']}}{% endif %}"
141-
"{{ message['text'] }}{{ bos_token }}"
142-
"{% endfor %}"
143-
"Bot:"
144-
)
145-
# This is in English, but it's just here to make sure the chat control tokens are being added properly
146-
test_chats = [
147-
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
148-
[
149-
{"role": "system", "content": "You are a helpful chatbot."},
150-
{"role": "user", "content": "Hello!"},
151-
{"role": "assistant", "content": "Nice to meet you."},
152-
],
153-
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
154-
]
155-
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
156-
# fmt: off
157-
expected_tokens = [
158-
[2000, 1, 575, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 1, 968, 263, 314, 419, 366, 354, 294, 360, 1, 575, 541, 419],
159-
[2000, 1, 575, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 1, 968, 263, 314, 419, 366, 354, 294, 360, 1, 575, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 1, 575, 541, 419],
160-
[2000, 1, 575, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 1, 968, 263, 314, 419, 366, 354, 294, 360, 1, 575, 541, 419]
161-
]
162-
# fmt: on
163-
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
164-
self.assertListEqual(tokenized_chat, expected_tokens)

tests/test_tokenization_common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,9 @@ def test_chat_template(self):
950950
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=False
951951
)
952952
dict_output = tokenizer.apply_chat_template(
953-
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=True
953+
dummy_conversation,
954+
chat_template=dummy_template,
955+
tokenize=True, # This also checks return_dict=True is the default
954956
)
955957
self.assertEqual(dict_output["input_ids"], output) # Test return_dict behaviour matches
956958

0 commit comments

Comments (0)