diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index 9a802831e..9c3f446ae 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -240,6 +240,7 @@ jobs:
           # - "No_Position_Experiment"
           - "Othello_GPT"
           - "Patchscopes_Generation_Demo"
+          - "stable_lm"
           # - "T5"
     steps:
       - uses: actions/checkout@v3
diff --git a/demos/stable_lm.ipynb b/demos/stable_lm.ipynb
index bfe623c36..f8fd37b00 100644
--- a/demos/stable_lm.ipynb
+++ b/demos/stable_lm.ipynb
@@ -60,21 +60,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {
     "id": "PXB6xkimoH2h"
    },
    "outputs": [],
    "source": [
     "import torch\n",
-    "from transformer_lens import HookedTransformer\n",
+    "from transformer_lens.model_bridge import TransformerBridge\n",
     "\n",
-    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
+    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+    "dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/",
@@ -355,7 +356,8 @@
    "source": [
     "# Load the 3 billion parameters version in 16 bits\n",
     "# You can increase the precision or the size if you have enough GPU RAM available\n",
-    "model = HookedTransformer.from_pretrained(\"stabilityai/stablelm-tuned-alpha-3b\", torch_dtype=torch.bfloat16, device=device)"
+    "model = TransformerBridge.boot_transformers(\"stabilityai/stablelm-tuned-alpha-3b\", dtype=dtype, device=device)\n",
+    "model.enable_compatibility_mode()"
    ]
   },
   {
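
For reference, the two migrated notebook cells in this patch reduce to the standalone sketch below. It only restates what the diff already changes (the `TransformerBridge.boot_transformers` loader and `enable_compatibility_mode` call); it assumes a TransformerLens version that ships `transformer_lens.model_bridge` and is not itself part of the patch.

```python
# Minimal sketch of the migrated stable_lm demo cells (not part of the patch).
# Assumes a TransformerLens release that provides transformer_lens.model_bridge.
import torch
from transformer_lens.model_bridge import TransformerBridge

# Pick device and precision: bfloat16 on GPU to fit the 3B model, float32 on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Boot the Hugging Face checkpoint through the bridge, replacing the old
# HookedTransformer.from_pretrained(...) call.
model = TransformerBridge.boot_transformers(
    "stabilityai/stablelm-tuned-alpha-3b", dtype=dtype, device=device
)
# Expose the HookedTransformer-style API surface that the rest of the demo expects.
model.enable_compatibility_mode()
```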