@@ -348,3 +348,34 @@ def TensorCreate_FromBlob(record):
348348
349349 values = con .execute_command ('AI.TENSORGET' , 'test2_res{1}' , 'VALUES' )
350350 env .assertEqual (values , [5 , 6 , 7 , 8 ])
351+
352+
353+ @skip_if_gears_not_loaded
354+ def test_flatten_tensor_via_gears (env ):
355+ script = '''
356+
357+ import redisAI
358+
359+ def FlattenTensor(record):
360+
361+ tensor = redisAI.createTensorFromValues('DOUBLE', [2,2], [1.0, 2.0, 3.0, 4.0])
362+ tensor_as_list = redisAI.tensorToFlatList(tensor)
363+ if tensor_as_list != [1.0, 2.0, 3.0, 4.0]:
364+ return "ERROR failed to flatten tensor to list of doubles"
365+
366+ tensor_blob = bytearray([5, 0, 6, 0, 7, 0, 8, 0])
367+ tensor = redisAI.createTensorFromBlob('UINT16', [2,2], tensor_blob)
368+ tensor_as_list = redisAI.tensorToFlatList(tensor)
369+ if tensor_as_list != [5, 6, 7, 8]:
370+ return "ERROR failed to flatten tensor to list of long long"
371+ return "test_OK"
372+
373+
374+ GB("CommandReader").map(FlattenTensor).register(trigger="FlattenTensor_test")
375+ '''
376+
377+ con = env .getConnection ()
378+ ret = con .execute_command ('rg.pyexecute' , script )
379+ env .assertEqual (ret , b'OK' )
380+ ret = con .execute_command ('rg.trigger' , 'FlattenTensor_test' )
381+ env .assertEqual (ret [0 ], b'test_OK' )
0 commit comments