Skip to content

Commit 32547fb

Browse files
committed
accept char and cellstr input for generate
1 parent 5200776 commit 32547fb

File tree

7 files changed

+21
-12
lines changed

7 files changed

+21
-12
lines changed

+llms/+utils/mustBeNonzeroLengthTextScalar.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,8 @@ function mustBeNonzeroLengthTextScalar(content)
55

66
% Copyright 2024 The MathWorks, Inc.
77
mustBeNonzeroLengthText(content)
8+
if iscellstr(content)
9+
content = string(content);
10+
end
811
mustBeTextScalar(content)
912
end

azureChat.m

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,14 @@
171171

172172
arguments
173173
this (1,1) azureChat
174-
messages (1,1) {mustBeValidMsgs}
174+
messages {mustBeValidMsgs}
175175
nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1
176176
nvp.MaxNumTokens (1,1) {mustBePositive} = inf
177177
nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = []
178178
nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
179179
end
180180

181+
messages = convertCharsToStrings(messages);
181182
if isstring(messages) && isscalar(messages)
182183
messagesStruct = {struct("role", "user", "content", messages)};
183184
else

ollamaChat.m

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,13 @@
135135

136136
arguments
137137
this (1,1) ollamaChat
138-
messages (1,1) {mustBeValidMsgs}
138+
messages {mustBeValidMsgs}
139139
nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1
140140
nvp.MaxNumTokens (1,1) {mustBePositive} = inf
141141
nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
142142
end
143143

144+
messages = convertCharsToStrings(messages);
144145
if isstring(messages) && isscalar(messages)
145146
messagesStruct = {struct("role", "user", "content", messages)};
146147
else

openAIChat.m

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@
168168

169169
arguments
170170
this (1,1) openAIChat
171-
messages (1,1) {mustBeValidMsgs}
171+
messages {mustBeValidMsgs}
172172
nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1
173173
nvp.MaxNumTokens (1,1) {mustBePositive} = inf
174174
nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = []
@@ -177,6 +177,7 @@
177177

178178
toolChoice = convertToolChoice(this, nvp.ToolChoice);
179179

180+
messages = convertCharsToStrings(messages);
180181
if isstring(messages) && isscalar(messages)
181182
messagesStruct = {struct("role", "user", "content", messages)};
182183
else

tests/tazureChat.m

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
InvalidConstructorInput = iGetInvalidConstructorInput;
88
InvalidGenerateInput = iGetInvalidGenerateInput;
99
InvalidValuesSetters = iGetInvalidValuesSetters;
10+
StringInputs = struct('string',{"hi"},'char',{'hi'},'cellstr',{{'hi'}});
1011
end
1112

1213
methods(Test)
@@ -32,10 +33,10 @@ function constructChatWithAllNVP(testCase)
3233
testCase.verifyEqual(chat.PresencePenalty, presenceP);
3334
end
3435

35-
function doGenerate(testCase)
36+
function doGenerate(testCase,StringInputs)
3637
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
3738
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"));
38-
response = testCase.verifyWarningFree(@() generate(chat,"hi"));
39+
response = testCase.verifyWarningFree(@() generate(chat,StringInputs));
3940
testCase.verifyClass(response,'string');
4041
testCase.verifyGreaterThan(strlength(response),0);
4142
end

tests/tollamaChat.m

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
InvalidGenerateInput = iGetInvalidGenerateInput;
99
InvalidValuesSetters = iGetInvalidValuesSetters;
1010
ValidValuesSetters = iGetValidValuesSetters;
11+
StringInputs = struct('string',{"hi"},'char',{'hi'},'cellstr',{{'hi'}});
1112
end
1213

1314
methods(Test)
@@ -31,9 +32,9 @@ function constructChatWithAllNVP(testCase)
3132
testCase.verifyEqual(chat.StopSequences, stop);
3233
end
3334

34-
function doGenerate(testCase)
35+
function doGenerate(testCase,StringInputs)
3536
chat = ollamaChat("mistral");
36-
response = testCase.verifyWarningFree(@() generate(chat,"hi"));
37+
response = testCase.verifyWarningFree(@() generate(chat,StringInputs));
3738
testCase.verifyClass(response,'string');
3839
testCase.verifyGreaterThan(strlength(response),0);
3940
end
@@ -290,7 +291,7 @@ function queryModels(testCase)
290291
invalidGenerateInput = struct( ...
291292
"EmptyInput",struct( ...
292293
"Input",{{ [] }},...
293-
"Error","MATLAB:validation:IncompatibleSize"),...
294+
"Error","llms:mustBeMessagesOrTxt"),...
294295
...
295296
"InvalidInputType",struct( ...
296297
"Input",{{ 123 }},...

tests/topenAIChat.m

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66
properties(TestParameter)
77
ValidConstructorInput = iGetValidConstructorInput();
88
InvalidConstructorInput = iGetInvalidConstructorInput();
9-
InvalidGenerateInput = iGetInvalidGenerateInput();
10-
InvalidValuesSetters = iGetInvalidValuesSetters();
9+
InvalidGenerateInput = iGetInvalidGenerateInput();
10+
InvalidValuesSetters = iGetInvalidValuesSetters();
11+
StringInputs = struct('string',{"hi"},'char',{'hi'},'cellstr',{{'hi'}});
1112
end
1213

1314
methods(Test)
1415
% Test methods
15-
function generateAcceptsSingleStringAsInput(testCase)
16+
function generateAcceptsSingleStringAsInput(testCase,StringInputs)
1617
chat = openAIChat;
17-
testCase.verifyWarningFree(@()generate(chat,"This is okay"));
18+
testCase.verifyWarningFree(@()generate(chat,StringInputs));
1819
end
1920

2021
function generateAcceptsMessagesAsInput(testCase)

0 commit comments

Comments
 (0)