Skip to content

Commit f80db75

Browse files
author
DvirDukhan
committed
added NA for torch and TFLITE
1 parent c35e52f commit f80db75

File tree

8 files changed

+80
-6
lines changed

8 files changed

+80
-6
lines changed

src/backends.c

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,17 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) {
224224
return REDISMODULE_ERR;
225225
}
226226

227+
backend.get_version =
228+
(const char *(*)(void))(unsigned long)dlsym(handle, "RAI_GetBackendVersionTFLite");
229+
if (backend.get_version == NULL) {
230+
dlclose(handle);
231+
RedisModule_Log(ctx, "warning",
232+
"Backend does not export RAI_GetBackendVersionTFLite. TFLite backend "
233+
"not loaded from %s",
234+
path);
235+
return REDISMODULE_ERR;
236+
}
237+
227238
RAI_backends.tflite = backend;
228239

229240
RedisModule_Log(ctx, "notice", "TFLITE backend loaded from %s", path);
@@ -338,6 +349,17 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) {
338349
return REDISMODULE_ERR;
339350
}
340351

352+
backend.get_version =
353+
(const char *(*)(void))(unsigned long)dlsym(handle, "RAI_GetBackendVersionTorch");
354+
if (backend.get_version == NULL) {
355+
dlclose(handle);
356+
RedisModule_Log(ctx, "warning",
357+
"Backend does not export RAI_GetBackendVersionTorch. TORCH backend "
358+
"not loaded from %s",
359+
path);
360+
return REDISMODULE_ERR;
361+
}
362+
341363
RAI_backends.torch = backend;
342364

343365
RedisModule_Log(ctx, "notice", "TORCH backend loaded from %s", path);

src/backends/tflite.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,5 @@ int RAI_ModelSerializeTFLite(RAI_Model *model, char **buffer, size_t *len, RAI_E
237237

238238
return 0;
239239
}
240+
241+
const char *RAI_GetBackendVersionTFLite(void) { return "NA"; }

src/backends/tflite.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
#ifndef SRC_BACKENDS_TFLITE_H_
2-
#define SRC_BACKENDS_TFLITE_H_
1+
#pragma once
32

43
#include "config.h"
54
#include "tensor_struct.h"
@@ -17,4 +16,4 @@ int RAI_ModelRunTFLite(RAI_ModelRunCtx **mctxs, RAI_Error *error);
1716

1817
int RAI_ModelSerializeTFLite(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error);
1918

20-
#endif /* SRC_BACKENDS_TFLITE_H_ */
19+
const char *RAI_GetBackendVersionTFLite(void);

src/backends/torch.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,5 @@ int RAI_ScriptRunTorch(RAI_ScriptRunCtx *sctx, RAI_Error *error) {
367367

368368
return 0;
369369
}
370+
371+
const char *RAI_GetBackendVersionTorch(void) { return "NA"; }

src/backends/torch.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
#ifndef SRC_BACKENDS_TORCH_H_
2-
#define SRC_BACKENDS_TORCH_H_
1+
#pragma once
32

43
#include "config.h"
54
#include "tensor_struct.h"
@@ -24,4 +23,4 @@ void RAI_ScriptFreeTorch(RAI_Script *script, RAI_Error *error);
2423

2524
int RAI_ScriptRunTorch(RAI_ScriptRunCtx *sctx, RAI_Error *error);
2625

27-
#endif /* SRC_BACKENDS_TORCH_H_ */
26+
const char *RAI_GetBackendVersionTorch(void);

src/redisai.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,16 @@ void _RedisAI_Info(RedisModuleCtx *ctx) {
844844
RedisModule_ReplyWithSimpleString(ctx, RAI_backends.tf.get_version());
845845
}
846846

847+
if (RAI_backends.torch.get_version) {
848+
RedisModule_ReplyWithSimpleString(ctx, "Torch version");
849+
RedisModule_ReplyWithSimpleString(ctx, RAI_backends.torch.get_version());
850+
}
851+
852+
if (RAI_backends.tflite.get_version) {
853+
RedisModule_ReplyWithSimpleString(ctx, "TFLite version");
854+
RedisModule_ReplyWithSimpleString(ctx, RAI_backends.tflite.get_version());
855+
}
856+
847857
if (RAI_backends.onnx.get_version) {
848858
RedisModule_ReplyWithSimpleString(ctx, "ONNX version");
849859
RedisModule_ReplyWithSimpleString(ctx, RAI_backends.onnx.get_version());

tests/flow/tests_pytorch.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,3 +984,22 @@ def test_modelget_for_tuple_output(env):
984984
env.assertEqual(ret[9], 0)
985985
env.assertEqual(len(ret[11]), 2)
986986
env.assertEqual(len(ret[13]), 2)
987+
988+
def test_torch_info(env):
989+
if not TEST_PT:
990+
env.debugPrint("skipping {}".format(sys._getframe().f_code.co_name), force=True)
991+
return
992+
con = env.getConnection()
993+
994+
ret = con.execute_command('AI.INFO')
995+
env.assertEqual(6, len(ret))
996+
997+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
998+
model_filename = os.path.join(test_data_path, 'pt-minimal-bb.pt')
999+
with open(model_filename, 'rb') as f:
1000+
model_pb = f.read()
1001+
ret = con.execute_command('AI.MODELSET', 'm{1}', 'TORCH', DEVICE, 'BLOB', model_pb)
1002+
1003+
ret = con.execute_command('AI.INFO')
1004+
env.assertEqual(8, len(ret))
1005+
env.assertEqual(b'Torch version', ret[6])

tests/flow/tests_tflite.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,24 @@ def test_tflite_model_rdb_save_load(env):
354354
env.assertTrue(model_serialized_memory == model_serialized_after_rdbload)
355355
# Assert input model binary is equal to loaded model binary
356356
env.assertTrue(model_pb == model_serialized_after_rdbload)
357+
358+
def test_tflite_info(env):
359+
if not TEST_TFLITE:
360+
env.debugPrint("skipping {}".format(sys._getframe().f_code.co_name), force=True)
361+
return
362+
con = env.getConnection()
363+
364+
ret = con.execute_command('AI.INFO')
365+
env.assertEqual(6, len(ret))
366+
367+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
368+
model_filename = os.path.join(test_data_path, 'mnist_model_quant.tflite')
369+
370+
with open(model_filename, 'rb') as f:
371+
model_pb = f.read()
372+
373+
ret = con.execute_command('AI.MODELSET', 'mnist{1}', 'TFLITE', 'CPU', 'BLOB', model_pb)
374+
375+
ret = con.execute_command('AI.INFO')
376+
env.assertEqual(8, len(ret))
377+
env.assertEqual(b'TFLite version', ret[6])

0 commit comments

Comments
 (0)