|
39 | 39 | }, |
40 | 40 | { |
41 | 41 | "cell_type": "code", |
42 | | - "execution_count": 1, |
| 42 | + "execution_count": 4, |
43 | 43 | "metadata": {}, |
44 | 44 | "outputs": [], |
45 | 45 | "source": [ |
|
51 | 51 | }, |
52 | 52 | { |
53 | 53 | "cell_type": "code", |
54 | | - "execution_count": 2, |
| 54 | + "execution_count": 5, |
55 | 55 | "metadata": {}, |
56 | 56 | "outputs": [], |
57 | 57 | "source": [ |
|
75 | 75 | }, |
76 | 76 | { |
77 | 77 | "cell_type": "code", |
78 | | - "execution_count": 3, |
| 78 | + "execution_count": 6, |
79 | 79 | "metadata": {}, |
80 | 80 | "outputs": [], |
81 | 81 | "source": [ |
|
228 | 228 | }, |
229 | 229 | { |
230 | 230 | "cell_type": "code", |
231 | | - "execution_count": 8, |
| 231 | + "execution_count": 7, |
232 | 232 | "metadata": {}, |
233 | 233 | "outputs": [], |
234 | 234 | "source": [ |
|
443 | 443 | }, |
444 | 444 | { |
445 | 445 | "cell_type": "code", |
446 | | - "execution_count": 15, |
| 446 | + "execution_count": 1, |
447 | 447 | "metadata": {}, |
448 | 448 | "outputs": [], |
449 | 449 | "source": [ |
|
461 | 461 | }, |
462 | 462 | { |
463 | 463 | "cell_type": "code", |
464 | | - "execution_count": 16, |
| 464 | + "execution_count": 2, |
465 | 465 | "metadata": {}, |
466 | 466 | "outputs": [], |
467 | 467 | "source": [ |
468 | 468 | "with open(\"models/fraud_detection_model.pt\", \"rb\") as f:\n", |
469 | 469 | " fraud_detection_model_blob = f.read()\n", |
470 | 470 | "\n", |
471 | | - "with open(\"torch_shapely.py\", \"rb\") as f:\n", |
| 471 | + "with open(\"torch_shapley.py\", \"rb\") as f:\n", |
472 | 472 | " shapely_script = f.read()" |
473 | 473 | ] |
474 | 474 | }, |
475 | 475 | { |
476 | 476 | "cell_type": "markdown", |
477 | 477 | "metadata": {}, |
478 | 478 | "source": [ |
479 | | - "We load both movel and script into RedisAI." |
| 479 | + "We load both model and script into RedisAI." |
480 | 480 | ] |
481 | 481 | }, |
482 | 482 | { |
483 | 483 | "cell_type": "code", |
484 | | - "execution_count": 17, |
| 484 | + "execution_count": 3, |
485 | 485 | "metadata": {}, |
486 | 486 | "outputs": [ |
487 | 487 | { |
|
490 | 490 | "'OK'" |
491 | 491 | ] |
492 | 492 | }, |
493 | | - "execution_count": 17, |
| 493 | + "execution_count": 3, |
494 | 494 | "metadata": {}, |
495 | 495 | "output_type": "execute_result" |
496 | 496 | } |
497 | 497 | ], |
498 | 498 | "source": [ |
499 | 499 | "rai.modelstore(\"fraud_detection_model\", \"TORCH\", \"CPU\", fraud_detection_model_blob)\n", |
500 | | - "rai.scriptstore(\"shapely_script\", device='CPU', script=shapely_script, entry_points=[\"shapely_sample\"] )" |
| 500 | + "rai.scriptstore(\"shapley_script\", device='CPU', script=shapely_script, entry_points=[\"shapley_sample\"] )" |
501 | 501 | ] |
502 | 502 | }, |
503 | 503 | { |
|
509 | 509 | }, |
510 | 510 | { |
511 | 511 | "cell_type": "code", |
512 | | - "execution_count": 18, |
| 512 | + "execution_count": 8, |
513 | 513 | "metadata": {}, |
514 | 514 | "outputs": [ |
515 | 515 | { |
|
523 | 523 | "source": [ |
524 | 524 | "rai.tensorset(\"fraud_input\", X_test_fraud, dtype=\"float\")\n", |
525 | 525 | "\n", |
526 | | - "rai.scriptexecute(\"shapely_script\", \"shapely_sample\", inputs = [\"fraud_input\"], keys = [\"fraud_detection_model\"], args = [\"20\", \"2\", \"0\"], outputs=[\"fraud_explanations\"])\n", |
| 526 | + "rai.scriptexecute(\"shapley_script\", \"shapley_sample\", inputs = [\"fraud_input\"], keys = [\"fraud_detection_model\"], args = [\"20\", \"2\", \"0\"], outputs=[\"fraud_explanations\"])\n", |
527 | 527 | "\n", |
528 | 528 | "rai_expl = rai.tensorget(\"fraud_explanations\")\n", |
529 | 529 | "\n", |
|
541 | 541 | }, |
542 | 542 | { |
543 | 543 | "cell_type": "code", |
544 | | - "execution_count": 19, |
| 544 | + "execution_count": 9, |
545 | 545 | "metadata": {}, |
546 | 546 | "outputs": [ |
547 | 547 | { |
548 | 548 | "data": { |
549 | 549 | "text/plain": [ |
550 | | - "<redisai.dag.Dag at 0x7f8a941a52e0>" |
| 550 | + "<redisai.dag.Dag at 0x7f80118a6640>" |
551 | 551 | ] |
552 | 552 | }, |
553 | | - "execution_count": 19, |
| 553 | + "execution_count": 9, |
554 | 554 | "metadata": {}, |
555 | 555 | "output_type": "execute_result" |
556 | 556 | } |
|
573 | 573 | }, |
574 | 574 | { |
575 | 575 | "cell_type": "code", |
576 | | - "execution_count": 20, |
| 576 | + "execution_count": 11, |
577 | 577 | "metadata": {}, |
578 | 578 | "outputs": [], |
579 | 579 | "source": [ |
|
634 | 634 | "\n", |
635 | 635 | "print(\"Winning feature: %d\" % winning_feature_redisai_dag)" |
636 | 636 | ] |
| 637 | + }, |
| 638 | + { |
| 639 | + "cell_type": "code", |
| 640 | + "execution_count": 13, |
| 641 | + "metadata": {}, |
| 642 | + "outputs": [ |
| 643 | + { |
| 644 | + "data": { |
| 645 | + "text/plain": [ |
| 646 | + "array([ 0. , -0.05, 0. , 0.05, 0.2 , 0. , 0. , 0. , 0.05,\n", |
| 647 | + " 0. , 0. , 0. , 0.3 , 0. , 0.4 , 0. , 0. , -0.05,\n", |
| 648 | + " 0. , 0.05, 0. , 0.05, 0. , -0.05, 0.05, 0. , 0. ,\n", |
| 649 | + " 0. , 0. , 0. ])" |
| 650 | + ] |
| 651 | + }, |
| 652 | + "execution_count": 13, |
| 653 | + "metadata": {}, |
| 654 | + "output_type": "execute_result" |
| 655 | + } |
| 656 | + ], |
| 657 | + "source": [ |
| 658 | + "dag_expl[1]" |
| 659 | + ] |
637 | 660 | } |
638 | 661 | ], |
639 | 662 | "metadata": { |
|
0 commit comments