Skip to content

Commit 49f0605

Browse files
committed
Check that blob len match the tensor dimensions and type in TENSORSET.
- Fix TENSORSET error tests so they would fail if no exception has occurred.
1 parent 0d00df4 commit 49f0605

File tree

3 files changed

+39
-38
lines changed

3 files changed

+39
-38
lines changed

src/DAG/dag.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,7 @@ static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
122122
RedisModule_ReplyWithError(ctx, "ERR could not save tensor");
123123
goto clean_up;
124124
}
125-
if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, tensor) !=
126-
REDISMODULE_OK) {
125+
if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, tensor) != REDISMODULE_OK) {
127126
RedisModule_ReplyWithError(ctx, "ERR could not save tensor");
128127
RedisModule_CloseKey(key);
129128
goto clean_up;

src/tensor.c

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,11 @@ int Tensor_DataTypeStr(DLDataType dtype, char *dtypestr) {
9797

9898
RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, int ndims,
9999
int tensorAllocMode) {
100-
const size_t dtypeSize = Tensor_DataTypeSize(dtype);
101-
if (dtypeSize == 0) {
102-
return NULL;
103-
}
104100

105101
RAI_Tensor *ret = RedisModule_Alloc(sizeof(*ret));
106102
int64_t *shape = RedisModule_Alloc(ndims * sizeof(*shape));
107103
int64_t *strides = RedisModule_Alloc(ndims * sizeof(*strides));
104+
size_t dtypeSize = Tensor_DataTypeSize(dtype);
108105

109106
size_t len = 1;
110107
for (int64_t i = 0; i < ndims; ++i) {
@@ -133,11 +130,6 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
133130
break;
134131
}
135132

136-
if (tensorAllocMode != TENSORALLOC_NONE && data == NULL) {
137-
RedisModule_Free(ret);
138-
return NULL;
139-
}
140-
141133
ret->tensor = (DLManagedTensor){.dl_tensor = (DLTensor){.ctx = ctx,
142134
.data = data,
143135
.ndim = ndims,
@@ -167,14 +159,10 @@ void RAI_RStringDataTensorDeleter(DLManagedTensor *arg) {
167159
RedisModule_Free(arg);
168160
}
169161

170-
RAI_Tensor *RAI_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, long long *dims, int ndims,
171-
RedisModuleString *rstr) {
172-
const size_t dtypeSize = Tensor_DataTypeSize(dtype);
173-
if (dtypeSize == 0) {
174-
return NULL;
175-
}
162+
RAI_Tensor *_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, size_t dtypeSize,
163+
long long *dims, int ndims,
164+
RedisModuleString *rstr, RAI_Error *err) {
176165

177-
RAI_Tensor *ret = RedisModule_Alloc(sizeof(*ret));
178166
int64_t *shape = RedisModule_Alloc(ndims * sizeof(*shape));
179167
int64_t *strides = RedisModule_Alloc(ndims * sizeof(*strides));
180168

@@ -189,12 +177,20 @@ RAI_Tensor *RAI_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, long long
189177
}
190178

191179
DLContext ctx = (DLContext){.device_type = kDLCPU, .device_id = 0};
192-
193-
long long nbytes = len*dtypeSize;
194-
180+
size_t nbytes = len * dtypeSize;
181+
182+
size_t blob_len;
183+
const char *blob = RedisModule_StringPtrLen(rstr, &blob_len);
184+
if (blob_len != nbytes) {
185+
RedisModule_Free(shape);
186+
RedisModule_Free(strides);
187+
RAI_SetError(err, RAI_ETENSORSET, "ERR data length does not match tensor shape and type");
188+
return NULL;
189+
}
195190
char *data = RedisModule_Alloc(nbytes);
196-
memcpy(data, RedisModule_StringPtrLen(rstr, NULL), nbytes);
191+
memcpy(data, blob, nbytes);
197192

193+
RAI_Tensor *ret = RedisModule_Alloc(sizeof(*ret));
198194
ret->tensor = (DLManagedTensor){.dl_tensor = (DLTensor){.ctx = ctx,
199195
.data = data,
200196
.ndim = ndims,
@@ -641,13 +637,16 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
641637
RAI_SetError(error, RAI_ETENSORSET, "wrong number of arguments for 'AI.TENSORSET' command");
642638
return -1;
643639
}
640+
644641
// get the tensor datatype
645642
const char *typestr = RedisModule_StringPtrLen(argv[2], NULL);
646-
size_t datasize = RAI_TensorDataSizeFromString(typestr);
647-
if (!datasize) {
643+
DLDataType datatype = RAI_TensorDataTypeFromString(typestr);
644+
size_t datasize = Tensor_DataTypeSize(datatype);
645+
if (datasize == 0) {
648646
RAI_SetError(error, RAI_ETENSORSET, "ERR invalid data type");
649647
return -1;
650648
}
649+
651650
const char *fmtstr;
652651
int datafmt = TENSOR_NONE;
653652
int tensorAllocMode = TENSORALLOC_CALLOC;
@@ -703,23 +702,18 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
703702
}
704703
}
705704

706-
const long long nbytes = len * datasize;
707-
size_t datalen;
708-
const char *data;
709-
DLDataType datatype = RAI_TensorDataTypeFromString(typestr);
710705
if (datafmt == TENSOR_BLOB) {
711706
RedisModuleString *rstr = argv[argpos];
712707
RedisModule_RetainString(NULL, rstr);
713-
*t = RAI_TensorCreateWithDLDataTypeAndRString(datatype, dims, ndims, rstr);
708+
*t = _TensorCreateWithDLDataTypeAndRString(datatype, datasize, dims, ndims, rstr, error);
714709
} else {
715710
*t = RAI_TensorCreateWithDLDataType(datatype, dims, ndims, tensorAllocMode);
716711
}
717-
718-
if (!t) {
712+
if (!(*t)) {
719713
array_free(dims);
720-
RAI_SetError(error, RAI_ETENSORSET, "ERR could not create tensor");
721714
return -1;
722715
}
716+
723717
long i = 0;
724718
if (datafmt == TENSOR_VALUES) {
725719
for (; (argpos <= argc - 1) && (i < len); argpos++) {

tests/flow/tests_common.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def test_common_tensorset_error_replies(env):
5050
try:
5151
con.execute_command('SET','non-tensor','value')
5252
con.execute_command('AI.TENSORSET', 'non-tensor', 'INT32', 2, 'unsupported', 2, 3)
53+
env.assertFalse(True)
5354
except Exception as e:
5455
exception = e
5556
env.assertEqual(type(exception), redis.exceptions.ResponseError)
@@ -58,6 +59,7 @@ def test_common_tensorset_error_replies(env):
5859
# ERR invalid data type
5960
try:
6061
con.execute_command('AI.TENSORSET', 'z', 'INT128', 2, 'VALUES', 2, 3)
62+
env.assertFalse(True)
6163
except Exception as e:
6264
exception = e
6365
env.assertEqual(type(exception), redis.exceptions.ResponseError)
@@ -66,6 +68,7 @@ def test_common_tensorset_error_replies(env):
6668
# ERR invalid or negative value found in tensor shape
6769
try:
6870
con.execute_command('AI.TENSORSET', 'z', 'INT32', -1, 'VALUES', 2, 3)
71+
env.assertFalse(True)
6972
except Exception as e:
7073
exception = e
7174
env.assertEqual(type(exception), redis.exceptions.ResponseError)
@@ -74,6 +77,7 @@ def test_common_tensorset_error_replies(env):
7477
# ERR invalid argument found in tensor shape
7578
try:
7679
con.execute_command('AI.TENSORSET', 'z', 'INT32', 2, 'unsupported', 2, 3)
80+
env.assertFalse(True)
7781
except Exception as e:
7882
exception = e
7983
env.assertEqual(type(exception), redis.exceptions.ResponseError)
@@ -82,6 +86,7 @@ def test_common_tensorset_error_replies(env):
8286
# ERR invalid value
8387
try:
8488
con.execute_command('AI.TENSORSET', 'z', 'FLOAT', 2, 'VALUES', 2, 'A')
89+
env.assertFalse(True)
8590
except Exception as e:
8691
exception = e
8792
env.assertEqual(type(exception), redis.exceptions.ResponseError)
@@ -90,56 +95,58 @@ def test_common_tensorset_error_replies(env):
9095
# ERR invalid value
9196
try:
9297
con.execute_command('AI.TENSORSET', 'z', 'INT32', 2, 'VALUES', 2, 'A')
98+
env.assertFalse(True)
9399
except Exception as e:
94100
exception = e
95101
env.assertEqual(type(exception), redis.exceptions.ResponseError)
96102
env.assertEqual(exception.__str__(), "invalid value")
97103

98104
try:
99105
con.execute_command('AI.TENSORSET', 1)
106+
env.assertFalse(True)
100107
except Exception as e:
101108
exception = e
102109
env.assertEqual(type(exception), redis.exceptions.ResponseError)
103110

104111
try:
105112
con.execute_command('AI.TENSORSET', 'y', 'FLOAT')
106-
except Exception as e:
107-
exception = e
108-
env.assertEqual(type(exception), redis.exceptions.ResponseError)
109-
110-
try:
111-
con.execute_command('AI.TENSORSET', 'y', 'FLOAT', '2')
113+
env.assertFalse(True)
112114
except Exception as e:
113115
exception = e
114116
env.assertEqual(type(exception), redis.exceptions.ResponseError)
115117

116118
try:
117119
con.execute_command('AI.TENSORSET', 'y', 'FLOAT', 2, 'VALUES')
120+
env.assertFalse(True)
118121
except Exception as e:
119122
exception = e
120123
env.assertEqual(type(exception), redis.exceptions.ResponseError)
121124

122125
try:
123126
con.execute_command('AI.TENSORSET', 'y', 'FLOAT', 2, 'VALUES', 1)
127+
env.assertFalse(True)
124128
except Exception as e:
125129
exception = e
126130
env.assertEqual(type(exception), redis.exceptions.ResponseError)
127131

128132
try:
129133
con.execute_command('AI.TENSORSET', 'y', 'FLOAT', 2, 'VALUES', '1')
134+
env.assertFalse(True)
130135
except Exception as e:
131136
exception = e
132137
env.assertEqual(type(exception), redis.exceptions.ResponseError)
133138

134139
try:
135140
con.execute_command('AI.TENSORSET', 'blob_tensor_moreargs', 'FLOAT', 2, 'BLOB', '\x00', 'extra-argument')
141+
env.assertFalse(True)
136142
except Exception as e:
137143
exception = e
138144
env.assertEqual(type(exception), redis.exceptions.ResponseError)
139145
env.assertEqual("wrong number of arguments for 'AI.TENSORSET' command", exception.__str__())
140146

141147
try:
142148
con.execute_command('AI.TENSORSET', 'blob_tensor_lessargs', 'FLOAT', 2, 'BLOB')
149+
env.assertFalse(True)
143150
except Exception as e:
144151
exception = e
145152
env.assertEqual(type(exception), redis.exceptions.ResponseError)
@@ -148,6 +155,7 @@ def test_common_tensorset_error_replies(env):
148155
# ERR data length does not match tensor shape and type
149156
try:
150157
con.execute_command('AI.TENSORSET', 'sample_raw_wrong_blob_for_dim', 'FLOAT', 1, 1, 28, 280, 'BLOB', sample_raw)
158+
env.assertFalse(True)
151159
except Exception as e:
152160
exception = e
153161
env.assertEqual(type(exception), redis.exceptions.ResponseError)

0 commit comments

Comments
 (0)