Skip to content

Commit 1c677ea

Browse files
committed
Test and small refactor creating tensor from values through gears.
1 parent 7564b42 commit 1c677ea

File tree

3 files changed

+34
-31
lines changed

3 files changed

+34
-31
lines changed

src/tensor.c

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,7 @@ RAI_Tensor *RAI_TensorNew(void) {
101101
ret->len = LEN_UNKOWN;
102102
}
103103

104-
RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, int ndims,
105-
int tensorAllocMode) {
104+
RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, int ndims) {
106105

107106
size_t dtypeSize = Tensor_DataTypeSize(dtype);
108107
if (dtypeSize == 0) {
@@ -124,21 +123,7 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
124123
}
125124

126125
DLContext ctx = (DLContext){.device_type = kDLCPU, .device_id = 0};
127-
void *data = NULL;
128-
switch (tensorAllocMode) {
129-
case TENSORALLOC_ALLOC:
130-
data = RedisModule_Alloc(len * dtypeSize);
131-
break;
132-
case TENSORALLOC_CALLOC:
133-
data = RedisModule_Calloc(len, dtypeSize);
134-
break;
135-
case TENSORALLOC_NONE:
136-
/* shallow copy no alloc */
137-
default:
138-
/* assume TENSORALLOC_NONE
139-
shallow copy no alloc */
140-
break;
141-
}
126+
void *data = RedisModule_Alloc(len * dtypeSize);
142127

143128
ret->tensor = (DLManagedTensor){.dl_tensor = (DLTensor){.ctx = ctx,
144129
.data = data,
@@ -214,9 +199,9 @@ RAI_Tensor *_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, size_t dtype
214199
return ret;
215200
}
216201

217-
RAI_Tensor *RAI_TensorCreate(const char *dataType, long long *dims, int ndims, int hasdata) {
202+
RAI_Tensor *RAI_TensorCreate(const char *dataType, long long *dims, int ndims) {
218203
DLDataType dtype = RAI_TensorDataTypeFromString(dataType);
219-
return RAI_TensorCreateWithDLDataType(dtype, dims, ndims, TENSORALLOC_ALLOC);
204+
return RAI_TensorCreateWithDLDataType(dtype, dims, ndims);
220205
}
221206

222207
#if 0
@@ -273,7 +258,7 @@ RAI_Tensor *RAI_TensorCreateByConcatenatingTensors(RAI_Tensor **ts, long long n)
273258

274259
DLDataType dtype = RAI_TensorDataType(ts[0]);
275260

276-
RAI_Tensor *ret = RAI_TensorCreateWithDLDataType(dtype, dims, ndims, TENSORALLOC_ALLOC);
261+
RAI_Tensor *ret = RAI_TensorCreateWithDLDataType(dtype, dims, ndims);
277262

278263
for (long long i = 0; i < n; i++) {
279264
memcpy(RAI_TensorData(ret) + batch_offsets[i] * sample_size * dtype_size,
@@ -300,7 +285,7 @@ RAI_Tensor *RAI_TensorCreateBySlicingTensor(RAI_Tensor *t, long long offset, lon
300285

301286
DLDataType dtype = RAI_TensorDataType(t);
302287

303-
RAI_Tensor *ret = RAI_TensorCreateWithDLDataType(dtype, dims, ndims, TENSORALLOC_ALLOC);
288+
RAI_Tensor *ret = RAI_TensorCreateWithDLDataType(dtype, dims, ndims);
304289

305290
memcpy(RAI_TensorData(ret), RAI_TensorData(t) + offset * sample_size * dtype_size,
306291
len * sample_size * dtype_size);
@@ -329,7 +314,7 @@ int RAI_TensorDeepCopy(RAI_Tensor *t, RAI_Tensor **dest) {
329314

330315
DLDataType dtype = RAI_TensorDataType(t);
331316

332-
RAI_Tensor *ret = RAI_TensorCreateWithDLDataType(dtype, dims, ndims, TENSORALLOC_ALLOC);
317+
RAI_Tensor *ret = RAI_TensorCreateWithDLDataType(dtype, dims, ndims);
333318

334319
memcpy(RAI_TensorData(ret), RAI_TensorData(t), sample_size * dtype_size);
335320
*dest = ret;
@@ -642,7 +627,6 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
642627

643628
const char *fmtstr;
644629
int datafmt = TENSOR_NONE;
645-
int tensorAllocMode = TENSORALLOC_CALLOC;
646630
size_t ndims = 0;
647631
long long len = 1;
648632
long long *dims = (long long *)array_new(long long, 1);
@@ -656,7 +640,6 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
656640
remaining_args = argc - 1 - argpos;
657641
if (!strcasecmp(opt, "BLOB")) {
658642
datafmt = TENSOR_BLOB;
659-
tensorAllocMode = TENSORALLOC_CALLOC;
660643
// if we've found the dataformat there are no more dimensions
661644
// check right away if the arity is correct
662645
if (remaining_args != 1 && enforceArity == 1) {
@@ -669,7 +652,6 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
669652
break;
670653
} else if (!strcasecmp(opt, "VALUES")) {
671654
datafmt = TENSOR_VALUES;
672-
tensorAllocMode = TENSORALLOC_CALLOC;
673655
// if we've found the dataformat there are no more dimensions
674656
// check right away if the arity is correct
675657
if (remaining_args != len && enforceArity == 1) {
@@ -699,7 +681,7 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
699681
RedisModuleString *rstr = argv[argpos];
700682
*t = _TensorCreateWithDLDataTypeAndRString(datatype, datasize, dims, ndims, rstr, error);
701683
} else {
702-
*t = RAI_TensorCreateWithDLDataType(datatype, dims, ndims, tensorAllocMode);
684+
*t = RAI_TensorCreateWithDLDataType(datatype, dims, ndims);
703685
}
704686
if (!(*t)) {
705687
array_free(dims);

src/tensor.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,10 @@ RAI_Tensor *RAI_TensorNew(void);
6565
* @param dataType string containing the numeric data type of tensor elements
6666
* @param dims n-dimensional array ( the dimension values are copied )
6767
* @param ndims number of dimensions
68-
* @param hasdata ( deprecated parameter )
6968
* @return allocated RAI_Tensor on success, or NULL if the allocation
7069
* failed.
7170
*/
72-
RAI_Tensor *RAI_TensorCreate(const char *dataType, long long *dims, int ndims, int hasdata);
71+
RAI_Tensor *RAI_TensorCreate(const char *dataType, long long *dims, int ndims);
7372

7473
/**
7574
* Allocate the memory and initialise the RAI_Tensor. Creates a tensor based on
@@ -81,12 +80,10 @@ RAI_Tensor *RAI_TensorCreate(const char *dataType, long long *dims, int ndims, i
8180
* @param dtype DLDataType
8281
* @param dims n-dimensional array ( the dimension values are copied )
8382
* @param ndims number of dimensions
84-
* @param tensorAllocMode
8583
* @return allocated RAI_Tensor on success, or NULL if the allocation
8684
* failed.
8785
*/
88-
RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, int ndims,
89-
int tensorAllocMode);
86+
RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, int ndims);
9087

9188
/**
9289
* Allocate the memory for a new Tensor and copy data fom a tensor to it.

tests/flow/tests_withGears.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,27 @@ async def DAGRun_addOpsFromString(record):
324324
values = con.execute_command('AI.TENSORGET', 'test5_res{1}', 'VALUES')
325325
env.assertEqual(values, [b'4', b'9', b'4', b'9'])
326326

327+
328+
@skip_if_gears_not_loaded
329+
def test_tensor_create_via_gears(env):
330+
script = '''
331+
332+
import redisAI
333+
334+
def TensorCreate_FromValues(record):
335+
336+
tensor = redisAI.createTensorFromValues('DOUBLE', [2,2], [1.0, 2.0, 3.0, 4.0])
337+
redisAI.setTensorInKey('test1_res{1}', tensor)
338+
return "test1_OK"
339+
340+
GB("CommandReader").map(TensorCreate_FromValues).register(trigger="TensorCreate_FromValues_test1")
341+
'''
342+
343+
con = env.getConnection()
344+
ret = con.execute_command('rg.pyexecute', script)
345+
env.assertEqual(ret, b'OK')
346+
ret = con.execute_command('rg.trigger', 'TensorCreate_FromValues_test1')
347+
env.assertEqual(ret[0], b'test1_OK')
348+
349+
values = con.execute_command('AI.TENSORGET', 'test1_res{1}', 'VALUES')
350+
env.assertEqual(values, [b'1', b'2', b'3', b'4'])

0 commit comments

Comments
 (0)