Skip to content

Commit ac14556

Browse files
author
DvirDukhan
committed
aof, rdb
1 parent 496494d commit ac14556

File tree

13 files changed

+479
-42
lines changed

13 files changed

+479
-42
lines changed

src/redisai.c

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
671671
return REDISMODULE_OK;
672672
}
673673

674-
int outentries = source ? 6 : 4;
674+
int outentries = source ? 8 : 6;
675675

676676
RedisModule_ReplyWithArray(ctx, outentries);
677677
RedisModule_ReplyWithCString(ctx, "device");
@@ -682,6 +682,12 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
682682
RedisModule_ReplyWithCString(ctx, "source");
683683
RedisModule_ReplyWithCString(ctx, sto->scriptdef);
684684
}
685+
RedisModule_ReplyWithCString(ctx, "Entry Points");
686+
size_t nEntryPoints = array_len(sto->entryPoints);
687+
RedisModule_ReplyWithArray(ctx, nEntryPoints);
688+
for (size_t i = 0; i < nEntryPoints; i++) {
689+
RedisModule_ReplyWithCString(ctx, sto->entryPoints[i]);
690+
}
685691
return REDISMODULE_OK;
686692
}
687693

src/serialization/AOF/rai_aof_rewrite.c

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,30 @@ void RAI_AOFRewriteModel(RedisModuleIO *aof, RedisModuleString *key, void *value
110110

111111
/**
 * AOF-rewrite callback for script keys: re-emits the script stored under `key`
 * as one AI.SCRIPTSTORE command.
 *
 * The emitted argument list is:
 *   key, device, "TAG", tag, ["ENTRY_POINTS", count, entry point...], "SOURCE", source
 * The ENTRY_POINTS section is written only when the script has at least one
 * entry point.
 *
 * @param aof   AOF IO handle the command is emitted to.
 * @param key   Name of the key that holds the script.
 * @param value The stored value; treated as a RAI_Script.
 */
void RAI_AOFRewriteScript(RedisModuleIO *aof, RedisModuleString *key, void *value) {
    RAI_Script *script = (RAI_Script *)value;
    size_t nEntryPoints = array_len(script->entryPoints);

    // 8 fixed arguments at most, plus one per entry point.
    RedisModuleString **args = array_new(RedisModuleString *, nEntryPoints + 8);
    args = array_append(args, RedisModule_CreateStringFromString(NULL, key));
    args = array_append(
        args, RedisModule_CreateString(NULL, script->devicestr, strlen(script->devicestr)));
    args = array_append(args, RedisModule_CreateString(NULL, "TAG", strlen("TAG")));
    args = array_append(args, RedisModule_CreateStringFromString(NULL, script->tag));

    if (nEntryPoints > 0) {
        args = array_append(args,
                            RedisModule_CreateString(NULL, "ENTRY_POINTS", strlen("ENTRY_POINTS")));
        args =
            array_append(args, RedisModule_CreateStringFromLongLong(NULL, (long long)nEntryPoints));
        for (size_t i = 0; i < nEntryPoints; i++) {
            args = array_append(args, RedisModule_CreateString(NULL, script->entryPoints[i],
                                                               strlen(script->entryPoints[i])));
        }
    }

    args = array_append(args, RedisModule_CreateString(NULL, "SOURCE", strlen("SOURCE")));
    args = array_append(
        args, RedisModule_CreateString(NULL, script->scriptdef, strlen(script->scriptdef)));

    // The "v" format specifier consumes two variadic arguments: the string
    // vector and its element count as a size_t. Omitting the count makes the
    // callee read an indeterminate value.
    const size_t argc = (size_t)array_len(args);
    RedisModule_EmitAOF(aof, "AI.SCRIPTSTORE", "v", args, argc);

    // The strings were created without a context, so they are released by hand.
    for (size_t i = 0; i < argc; i++) {
        RedisModule_FreeString(NULL, args[i]);
    }
    array_free(args);
}
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
#include "decode_v3.h"
2+
#include "assert.h"
3+
4+
/**
5+
* In case of IO errors, the default return values are:
6+
* numbers - 0
7+
* strings - null
8+
* So only when it is necessary check for IO errors.
9+
*/
10+
11+
/**
 * Reads one tensor from an RDB stream written with encoding version 3.
 *
 * Fields are consumed in this order: device type, device id, dtype bits,
 * dtype code, dtype lanes, number of dimensions, the shape values, the stride
 * values, the byte offset and finally the raw data buffer.
 *
 * @param io RDB IO handle to read from.
 * @return A new RAI_Tensor on success, or NULL after logging an error when the
 *         stream reported an IO error.
 */
void *RAI_RDBLoadTensor_v3(RedisModuleIO *io) {
    int64_t *dims = NULL;
    int64_t *steps = NULL;
    DLDevice dev;
    DLDataType elem_type;
    size_t rank;
    size_t offset;
    size_t blob_len;
    char *blob;
    RAI_Tensor *tensor;

    dev.device_type = RedisModule_LoadUnsigned(io);
    dev.device_id = RedisModule_LoadUnsigned(io);
    if (RedisModule_IsIOError(io)) {
        goto short_read;
    }

    // For now we only support CPU tensors (except during model and script run)
    assert(dev.device_type == kDLCPU);
    assert(dev.device_id == 0);

    elem_type.bits = RedisModule_LoadUnsigned(io);
    elem_type.code = RedisModule_LoadUnsigned(io);
    elem_type.lanes = RedisModule_LoadUnsigned(io);

    rank = RedisModule_LoadUnsigned(io);
    if (RedisModule_IsIOError(io)) {
        goto short_read;
    }

    // All shape values precede all stride values in the stream.
    dims = RedisModule_Calloc(rank, sizeof(*dims));
    for (size_t d = 0; d < rank; ++d) {
        dims[d] = RedisModule_LoadUnsigned(io);
    }

    steps = RedisModule_Calloc(rank, sizeof(*steps));
    for (size_t d = 0; d < rank; ++d) {
        steps[d] = RedisModule_LoadUnsigned(io);
    }

    offset = RedisModule_LoadUnsigned(io);
    blob = RedisModule_LoadStringBuffer(io, &blob_len);

    // A single check here covers every read performed since the previous check.
    if (RedisModule_IsIOError(io)) {
        goto short_read;
    }

    tensor = RAI_TensorNew();
    DLTensor dl = {.device = dev,
                   .data = blob,
                   .ndim = rank,
                   .dtype = elem_type,
                   .shape = dims,
                   .strides = steps,
                   .byte_offset = offset};
    tensor->tensor = (DLManagedTensor){.dl_tensor = dl, .manager_ctx = NULL, .deleter = NULL};
    return tensor;

short_read:
    if (dims != NULL) {
        RedisModule_Free(dims);
    }
    if (steps != NULL) {
        RedisModule_Free(steps);
    }
    RedisModule_LogIOError(io, "error", "Experienced a short read while reading a tensor from RDB");
    return NULL;
}
71+
72+
/**
 * Reads one model from an RDB stream written with encoding version 3 and
 * rebuilds it with RAI_ModelCreate.
 *
 * Fields are consumed in this order: backend, device string, tag, batch size,
 * minimum batch size, minimum batch timeout, input count and names, output
 * count and names, total model blob length, chunk count, and the chunks.
 *
 * If the backend is not loaded yet, the default backend is loaded and the
 * model creation is retried once.
 *
 * @param io RDB IO handle to read from.
 * @return The created RAI_Model, or NULL on an IO error, on a chunk list that
 *         does not fit the declared blob length, or when the model cannot be
 *         created.
 */
void *RAI_RDBLoadModel_v3(RedisModuleIO *io) {

    char *devicestr = NULL;
    RedisModuleString *tag = NULL;
    size_t ninputs = 0;
    const char **inputs = NULL;
    size_t noutputs = 0;
    const char **outputs = NULL;
    char *buffer = NULL;
    RAI_Model *loaded = NULL; // stays NULL on every failure path
    RAI_Error err = {0};

    RAI_Backend backend = RedisModule_LoadUnsigned(io);
    devicestr = RedisModule_LoadStringBuffer(io, NULL);
    tag = RedisModule_LoadString(io);

    const size_t batchsize = RedisModule_LoadUnsigned(io);
    const size_t minbatchsize = RedisModule_LoadUnsigned(io);
    const size_t minbatchtimeout = RedisModule_LoadUnsigned(io);

    ninputs = RedisModule_LoadUnsigned(io);
    if (RedisModule_IsIOError(io))
        goto cleanup;

    // Zero-filled so every slot is safe to free even if a later read fails.
    inputs = RedisModule_Calloc(ninputs, sizeof(*inputs));
    for (size_t i = 0; i < ninputs; i++) {
        inputs[i] = RedisModule_LoadStringBuffer(io, NULL);
    }

    noutputs = RedisModule_LoadUnsigned(io);
    if (RedisModule_IsIOError(io))
        goto cleanup;

    outputs = RedisModule_Calloc(noutputs, sizeof(*outputs));
    for (size_t i = 0; i < noutputs; i++) {
        outputs[i] = RedisModule_LoadStringBuffer(io, NULL);
    }

    RAI_ModelOpts opts = {
        .batchsize = batchsize,
        .minbatchsize = minbatchsize,
        .minbatchtimeout = minbatchtimeout,
        .backends_intra_op_parallelism = Config_GetBackendsIntraOpParallelism(),
        .backends_inter_op_parallelism = Config_GetBackendsInterOpParallelism(),
    };

    size_t len = RedisModule_LoadUnsigned(io);
    if (RedisModule_IsIOError(io))
        goto cleanup;

    buffer = RedisModule_Alloc(len);
    const size_t n_chunks = RedisModule_LoadUnsigned(io);
    // Without this check a short read yields n_chunks == 0 and the untouched
    // buffer would be handed to the backend.
    if (RedisModule_IsIOError(io))
        goto cleanup;

    size_t chunk_offset = 0;
    for (size_t i = 0; i < n_chunks; i++) {
        size_t chunk_len = 0;
        char *chunk_buffer = RedisModule_LoadStringBuffer(io, &chunk_len);
        if (RedisModule_IsIOError(io))
            goto cleanup;
        // chunk_offset never exceeds len, so the subtraction cannot wrap.
        if (chunk_len > len - chunk_offset) {
            RedisModule_Free(chunk_buffer);
            RedisModule_LogIOError(
                io, "error", "Model chunks exceed the declared model size while reading from RDB");
            goto cleanup;
        }
        memcpy(buffer + chunk_offset, chunk_buffer, chunk_len);
        chunk_offset += chunk_len;
        RedisModule_Free(chunk_buffer);
    }

    RAI_Model *model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs,
                                       outputs, buffer, len, &err);

    if (err.code == RAI_EBACKENDNOTLOADED) {
        RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
        int ret = RAI_LoadDefaultBackend(ctx, backend);
        if (ret == REDISMODULE_ERR) {
            RedisModule_Log(ctx, "warning", "Could not load default backend");
            RAI_ClearError(&err);
            goto cleanup;
        }
        RAI_ClearError(&err);
        model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, outputs,
                                buffer, len, &err);
    }

    if (err.code != RAI_OK) {
        RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
        RedisModule_Log(ctx, "warning", "%s", err.detail);
        RAI_ClearError(&err);
        goto cleanup;
    }

    RedisModuleCtx *stats_ctx = RedisModule_GetContextFromIO(io);
    RedisModuleString *stats_keystr =
        RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));

    model->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, devicestr, tag);
    RedisModule_FreeString(NULL, stats_keystr);

    loaded = model;

cleanup:
    // Shared by the success and failure paths: the temporaries read from the
    // stream are released here in both cases.
    if (loaded == NULL && RedisModule_IsIOError(io)) {
        RedisModule_LogIOError(io, "error",
                               "Experienced a short read while reading a model from RDB");
    }

    if (devicestr)
        RedisModule_Free(devicestr);
    if (tag)
        RedisModule_FreeString(NULL, tag);
    if (inputs) {
        for (size_t i = 0; i < ninputs; i++) {
            RedisModule_Free((void *)inputs[i]);
        }
        RedisModule_Free(inputs);
    }
    if (outputs) {
        for (size_t i = 0; i < noutputs; i++) {
            RedisModule_Free((void *)outputs[i]);
        }
        RedisModule_Free(outputs);
    }
    if (buffer)
        RedisModule_Free(buffer);

    return loaded;
}
205+
206+
/**
 * Reads one script from an RDB stream written with encoding version 3 and
 * compiles it with RAI_ScriptCompile.
 *
 * Fields are consumed in this order: device string, tag, script source,
 * entry point count, and the entry point names.
 *
 * If the TORCH backend is not loaded yet, the default one is loaded and the
 * compilation is retried once with the same arguments.
 *
 * @param io RDB IO handle to read from.
 * @return The compiled RAI_Script, or NULL on an IO error or when the script
 *         cannot be compiled.
 */
void *RAI_RDBLoadScript_v3(RedisModuleIO *io) {
    RedisModuleString *tag = NULL;
    char *devicestr = NULL;
    char *scriptdef = NULL;
    size_t nEntryPoints = 0;
    char **entryPoints = NULL;
    RAI_Script *loaded = NULL; // stays NULL on every failure path
    RAI_Error err = {0};

    size_t len;
    devicestr = RedisModule_LoadStringBuffer(io, &len);
    tag = RedisModule_LoadString(io);

    scriptdef = RedisModule_LoadStringBuffer(io, &len);
    if (RedisModule_IsIOError(io))
        goto cleanup;

    nEntryPoints = (size_t)RedisModule_LoadUnsigned(io);
    if (RedisModule_IsIOError(io))
        goto cleanup;
    entryPoints = array_new(char *, nEntryPoints);

    for (size_t i = 0; i < nEntryPoints; i++) {
        char *entryPoint = RedisModule_LoadStringBuffer(io, &len);
        if (RedisModule_IsIOError(io)) {
            goto cleanup;
        }
        entryPoints = array_append(entryPoints, entryPoint);
    }

    RAI_Script *script = RAI_ScriptCompile(devicestr, tag, scriptdef, (const char **)entryPoints,
                                           nEntryPoints, &err);

    if (err.code == RAI_EBACKENDNOTLOADED) {
        RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
        int ret = RAI_LoadDefaultBackend(ctx, RAI_BACKEND_TORCH);
        if (ret == REDISMODULE_ERR) {
            RedisModule_Log(ctx, "warning", "Could not load default TORCH backend\n");
            RAI_ClearError(&err);
            goto cleanup;
        }
        RAI_ClearError(&err);
        // Retry with the same arguments as the first attempt so the entry
        // points are not dropped when the backend is loaded lazily.
        script = RAI_ScriptCompile(devicestr, tag, scriptdef, (const char **)entryPoints,
                                   nEntryPoints, &err);
    }

    if (err.code != RAI_OK) {
        RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
        RedisModule_Log(ctx, "warning", "%s", err.detail);
        RAI_ClearError(&err);
        goto cleanup;
    }

    RedisModuleCtx *stats_ctx = RedisModule_GetContextFromIO(io);
    RedisModuleString *stats_keystr =
        RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));

    script->infokey =
        RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_SCRIPT, RAI_BACKEND_TORCH, devicestr, tag);

    RedisModule_FreeString(NULL, stats_keystr);
    loaded = script;

cleanup:
    // Shared by the success and failure paths: the temporaries read from the
    // stream are released here in both cases.
    if (loaded == NULL && RedisModule_IsIOError(io)) {
        RedisModule_LogIOError(io, "error",
                               "Experienced a short read while reading a script from RDB");
    }

    if (devicestr)
        RedisModule_Free(devicestr);
    if (scriptdef)
        RedisModule_Free(scriptdef);
    if (tag)
        RedisModule_FreeString(NULL, tag);
    if (entryPoints) {
        // Free only what was actually appended: after a short read the array
        // holds fewer than nEntryPoints elements.
        size_t nLoaded = array_len(entryPoints);
        for (size_t i = 0; i < nLoaded; i++) {
            RedisModule_Free(entryPoints[i]);
        }
        array_free(entryPoints);
    }
    return loaded;
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#pragma once
2+
#include "serialization/serialization_include.h"
3+
4+
void *RAI_RDBLoadTensor_v3(RedisModuleIO *io);
5+
6+
void *RAI_RDBLoadModel_v3(RedisModuleIO *io);
7+
8+
void *RAI_RDBLoadScript_v3(RedisModuleIO *io);

src/serialization/RDB/decoder/decode_previous.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ void *Decode_PreviousModel(RedisModuleIO *rdb, int encver) {
2323
return RAI_RDBLoadModel_v0(rdb);
2424
case 1:
2525
return RAI_RDBLoadModel_v1(rdb);
26+
case 2:
27+
return RAI_RDBLoadModel_v2(rdb);
2628
default:
2729
assert(false && "Invalid encoding version");
2830
}
@@ -35,6 +37,8 @@ void *Decode_PreviousScript(RedisModuleIO *rdb, int encver) {
3537
return RAI_RDBLoadScript_v0(rdb);
3638
case 1:
3739
return RAI_RDBLoadScript_v1(rdb);
40+
case 2:
41+
return RAI_RDBLoadScript_v2(rdb);
3842
default:
3943
assert(false && "Invalid encoding version");
4044
}
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include "rai_rdb_decoder.h"
2-
#include "current/v2/decode_v2.h"
2+
#include "current/v3/decode_v3.h"
33

4-
void *RAI_RDBLoadTensor(RedisModuleIO *io) { return RAI_RDBLoadTensor_v2(io); }
4+
void *RAI_RDBLoadTensor(RedisModuleIO *io) { return RAI_RDBLoadTensor_v3(io); }
55

6-
void *RAI_RDBLoadModel(RedisModuleIO *io) { return RAI_RDBLoadModel_v2(io); }
6+
void *RAI_RDBLoadModel(RedisModuleIO *io) { return RAI_RDBLoadModel_v3(io); }
77

8-
void *RAI_RDBLoadScript(RedisModuleIO *io) { return RAI_RDBLoadScript_v2(io); }
8+
void *RAI_RDBLoadScript(RedisModuleIO *io) { return RAI_RDBLoadScript_v3(io); }
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include "rai_rdb_encode.h"
2-
#include "v2/encode_v2.h"
2+
#include "v3/encode_v3.h"
33

4-
void RAI_RDBSaveTensor(RedisModuleIO *io, void *value) { RAI_RDBSaveTensor_v2(io, value); }
4+
void RAI_RDBSaveTensor(RedisModuleIO *io, void *value) { RAI_RDBSaveTensor_v3(io, value); }
55

6-
void RAI_RDBSaveModel(RedisModuleIO *io, void *value) { RAI_RDBSaveModel_v2(io, value); }
6+
void RAI_RDBSaveModel(RedisModuleIO *io, void *value) { RAI_RDBSaveModel_v3(io, value); }
77

8-
void RAI_RDBSaveScript(RedisModuleIO *io, void *value) { RAI_RDBSaveScript_v2(io, value); }
8+
void RAI_RDBSaveScript(RedisModuleIO *io, void *value) { RAI_RDBSaveScript_v3(io, value); }

0 commit comments

Comments
 (0)