Skip to content

Commit 1c21e9d

Browse files
author
DvirDukhan
committed
fixed typos
1 parent 0bf059b commit 1c21e9d

File tree

2 files changed

+41
-18
lines changed

2 files changed

+41
-18
lines changed

notebooks/shapley_explainability/XGBoostGenericShapleyFraudDetection.ipynb

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
},
4040
{
4141
"cell_type": "code",
42-
"execution_count": 1,
42+
"execution_count": 4,
4343
"metadata": {},
4444
"outputs": [],
4545
"source": [
@@ -51,7 +51,7 @@
5151
},
5252
{
5353
"cell_type": "code",
54-
"execution_count": 2,
54+
"execution_count": 5,
5555
"metadata": {},
5656
"outputs": [],
5757
"source": [
@@ -75,7 +75,7 @@
7575
},
7676
{
7777
"cell_type": "code",
78-
"execution_count": 3,
78+
"execution_count": 6,
7979
"metadata": {},
8080
"outputs": [],
8181
"source": [
@@ -228,7 +228,7 @@
228228
},
229229
{
230230
"cell_type": "code",
231-
"execution_count": 8,
231+
"execution_count": 7,
232232
"metadata": {},
233233
"outputs": [],
234234
"source": [
@@ -443,7 +443,7 @@
443443
},
444444
{
445445
"cell_type": "code",
446-
"execution_count": 15,
446+
"execution_count": 1,
447447
"metadata": {},
448448
"outputs": [],
449449
"source": [
@@ -461,27 +461,27 @@
461461
},
462462
{
463463
"cell_type": "code",
464-
"execution_count": 16,
464+
"execution_count": 2,
465465
"metadata": {},
466466
"outputs": [],
467467
"source": [
468468
"with open(\"models/fraud_detection_model.pt\", \"rb\") as f:\n",
469469
" fraud_detection_model_blob = f.read()\n",
470470
"\n",
471-
"with open(\"torch_shapely.py\", \"rb\") as f:\n",
471+
"with open(\"torch_shapley.py\", \"rb\") as f:\n",
472472
" shapely_script = f.read()"
473473
]
474474
},
475475
{
476476
"cell_type": "markdown",
477477
"metadata": {},
478478
"source": [
479-
"We load both movel and script into RedisAI."
479+
"We load both model and script into RedisAI."
480480
]
481481
},
482482
{
483483
"cell_type": "code",
484-
"execution_count": 17,
484+
"execution_count": 3,
485485
"metadata": {},
486486
"outputs": [
487487
{
@@ -490,14 +490,14 @@
490490
"'OK'"
491491
]
492492
},
493-
"execution_count": 17,
493+
"execution_count": 3,
494494
"metadata": {},
495495
"output_type": "execute_result"
496496
}
497497
],
498498
"source": [
499499
"rai.modelstore(\"fraud_detection_model\", \"TORCH\", \"CPU\", fraud_detection_model_blob)\n",
500-
"rai.scriptstore(\"shapely_script\", device='CPU', script=shapely_script, entry_points=[\"shapely_sample\"] )"
500+
"rai.scriptstore(\"shapley_script\", device='CPU', script=shapely_script, entry_points=[\"shapley_sample\"] )"
501501
]
502502
},
503503
{
@@ -509,7 +509,7 @@
509509
},
510510
{
511511
"cell_type": "code",
512-
"execution_count": 18,
512+
"execution_count": 8,
513513
"metadata": {},
514514
"outputs": [
515515
{
@@ -523,7 +523,7 @@
523523
"source": [
524524
"rai.tensorset(\"fraud_input\", X_test_fraud, dtype=\"float\")\n",
525525
"\n",
526-
"rai.scriptexecute(\"shapely_script\", \"shapely_sample\", inputs = [\"fraud_input\"], keys = [\"fraud_detection_model\"], args = [\"20\", \"2\", \"0\"], outputs=[\"fraud_explanations\"])\n",
526+
"rai.scriptexecute(\"shapley_script\", \"shapley_sample\", inputs = [\"fraud_input\"], keys = [\"fraud_detection_model\"], args = [\"20\", \"2\", \"0\"], outputs=[\"fraud_explanations\"])\n",
527527
"\n",
528528
"rai_expl = rai.tensorget(\"fraud_explanations\")\n",
529529
"\n",
@@ -541,16 +541,16 @@
541541
},
542542
{
543543
"cell_type": "code",
544-
"execution_count": 19,
544+
"execution_count": 9,
545545
"metadata": {},
546546
"outputs": [
547547
{
548548
"data": {
549549
"text/plain": [
550-
"<redisai.dag.Dag at 0x7f8a941a52e0>"
550+
"<redisai.dag.Dag at 0x7f80118a6640>"
551551
]
552552
},
553-
"execution_count": 19,
553+
"execution_count": 9,
554554
"metadata": {},
555555
"output_type": "execute_result"
556556
}
@@ -573,7 +573,7 @@
573573
},
574574
{
575575
"cell_type": "code",
576-
"execution_count": 20,
576+
"execution_count": 11,
577577
"metadata": {},
578578
"outputs": [],
579579
"source": [
@@ -634,6 +634,29 @@
634634
"\n",
635635
"print(\"Winning feature: %d\" % winning_feature_redisai_dag)"
636636
]
637+
},
638+
{
639+
"cell_type": "code",
640+
"execution_count": 13,
641+
"metadata": {},
642+
"outputs": [
643+
{
644+
"data": {
645+
"text/plain": [
646+
"array([ 0. , -0.05, 0. , 0.05, 0.2 , 0. , 0. , 0. , 0.05,\n",
647+
" 0. , 0. , 0. , 0.3 , 0. , 0.4 , 0. , 0. , -0.05,\n",
648+
" 0. , 0.05, 0. , 0.05, 0. , -0.05, 0.05, 0. , 0. ,\n",
649+
" 0. , 0. , 0. ])"
650+
]
651+
},
652+
"execution_count": 13,
653+
"metadata": {},
654+
"output_type": "execute_result"
655+
}
656+
],
657+
"source": [
658+
"dag_expl[1]"
659+
]
637660
}
638661
],
639662
"metadata": {

notebooks/shapley_explainability/torch_shapely.py renamed to notebooks/shapley_explainability/torch_shapley.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def index_with_target(x, target:int):
4545

4646
# binary classification - no need for target (output size is 1)
4747
# multiple output (output vector - target specifies the output index to explain.
48-
def shapely_sample(tensors: List[Tensor], keys: List[str], args: List[str]):
48+
def shapley_sample(tensors: List[Tensor], keys: List[str], args: List[str]):
4949
model_key = keys[0]
5050
x = tensors[0]
5151
n_samples = int(args[0])

0 commit comments

Comments
 (0)