Skip to content

Commit

Permalink
- Adding CCA notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
jalammar committed Aug 2, 2021
1 parent 5af0c94 commit efc5950
Show file tree
Hide file tree
Showing 3 changed files with 350 additions and 3 deletions.
337 changes: 337 additions & 0 deletions notebooks/Ecco_CCA_Similarity.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,337 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "3d3e8316-6769-44e1-b522-cb4b35fc4541",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/jalammar/ecco/blob/main/notebooks/Ecco_CCA_Similarity.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2aad7c95-bc94-4813-9b9f-dc47fdc442ee",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: ecco in c:\\users\\msx\\pycharmprojects\\ecco\\src (0.0.14)\n",
"Requirement already satisfied: transformers~=4.2 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from ecco) (4.6.1)\n",
"Requirement already satisfied: seaborn~=0.11 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from ecco) (0.11.1)\n",
"Requirement already satisfied: scikit-learn~=0.23 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from ecco) (0.24.2)\n",
"Requirement already satisfied: PyYAML~=5.4 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from ecco) (5.4.1)\n",
"Requirement already satisfied: numpy>=1.13.3 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from scikit-learn~=0.23->ecco) (1.19.5)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from scikit-learn~=0.23->ecco) (2.1.0)\n",
"Requirement already satisfied: joblib>=0.11 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from scikit-learn~=0.23->ecco) (1.0.1)\n",
"Requirement already satisfied: scipy>=0.19.1 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from scikit-learn~=0.23->ecco) (1.6.3)\n",
"Requirement already satisfied: pandas>=0.23 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from seaborn~=0.11->ecco) (1.2.4)\n",
"Requirement already satisfied: matplotlib>=2.2 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from seaborn~=0.11->ecco) (3.3.4)\n",
"Requirement already satisfied: cycler>=0.10 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from matplotlib>=2.2->seaborn~=0.11->ecco) (0.10.0)\n",
"Requirement already satisfied: pillow>=6.2.0 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from matplotlib>=2.2->seaborn~=0.11->ecco) (8.2.0)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in c:\\users\\msx\\appdata\\roaming\\python\\python39\\site-packages (from matplotlib>=2.2->seaborn~=0.11->ecco) (2.4.7)\n",
"Requirement already satisfied: python-dateutil>=2.1 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from matplotlib>=2.2->seaborn~=0.11->ecco) (2.8.1)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from matplotlib>=2.2->seaborn~=0.11->ecco) (1.3.1)\n",
"Requirement already satisfied: six in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from cycler>=0.10->matplotlib>=2.2->seaborn~=0.11->ecco) (1.16.0)\n",
"Requirement already satisfied: pytz>=2017.3 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from pandas>=0.23->seaborn~=0.11->ecco) (2021.1)\n",
"Requirement already satisfied: packaging in c:\\users\\msx\\appdata\\roaming\\python\\python39\\site-packages (from transformers~=4.2->ecco) (20.9)\n",
"Requirement already satisfied: tokenizers<0.11,>=0.10.1 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from transformers~=4.2->ecco) (0.10.3)\n",
"Requirement already satisfied: filelock in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from transformers~=4.2->ecco) (3.0.12)\n",
"Requirement already satisfied: sacremoses in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from transformers~=4.2->ecco) (0.0.45)\n",
"Requirement already satisfied: requests in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from transformers~=4.2->ecco) (2.25.1)\n",
"Requirement already satisfied: tqdm>=4.27 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from transformers~=4.2->ecco) (4.61.0)\n",
"Requirement already satisfied: huggingface-hub==0.0.8 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from transformers~=4.2->ecco) (0.0.8)\n",
"Requirement already satisfied: regex!=2019.12.17 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from transformers~=4.2->ecco) (2021.4.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from requests->transformers~=4.2->ecco) (2.10)\n",
"Requirement already satisfied: chardet<5,>=3.0.2 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from requests->transformers~=4.2->ecco) (4.0.0)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from requests->transformers~=4.2->ecco) (1.26.6)\n",
"Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from requests->transformers~=4.2->ecco) (2021.5.30)\n",
"Requirement already satisfied: click in c:\\users\\msx\\miniconda3\\envs\\ecco\\lib\\site-packages (from sacremoses->transformers~=4.2->ecco) (8.0.1)\n",
"Requirement already satisfied: colorama in c:\\users\\msx\\appdata\\roaming\\python\\python39\\site-packages (from click->sacremoses->transformers~=4.2->ecco) (0.4.4)\n"
]
}
],
"source": [
"!pip install ecco"
]
},
{
"cell_type": "markdown",
"id": "55290353-1c0f-4778-abd8-bb4c09969e1a",
"metadata": {},
"source": [
"Load Ecco and BERT."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "228061f6-2cfc-47ea-9357-6789c81745d1",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']\n",
"- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"import ecco\n",
"lm = ecco.from_pretrained('distilbert-base-uncased', gpu=False)"
]
},
{
"cell_type": "markdown",
"id": "27d1f126-e938-42f6-8e8c-fb97f79a2b74",
"metadata": {},
"source": [
"Let's give BERT a passage of text to proccess"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b1abd436-50bf-4722-b113-73eaa795020d",
"metadata": {},
"outputs": [],
"source": [
"text = '''Now I ask you: what can be expected of man since he is a being endowed with strange qualities? Shower upon him every earthly blessing, drown him in a sea of happiness, so that nothing but bubbles of bliss can be seen on the surface; give him economic prosperity, such that he should have nothing else to do but sleep, eat cakes and busy himself with the continuation of his species, and even then out of sheer ingratitude, sheer spite, man would play you some nasty trick. He would even risk his cakes and would deliberately desire the most fatal rubbish, the most uneconomical absurdity, simply to introduce into all this positive good sense his fatal fantastic element. It is just his fantastic dreams, his vulgar folly that he will desire to retain, simply in order to prove to himself--as though that were so necessary-- that men still are men and not the keys of a piano, which the laws of nature threaten to control so completely that soon one will be able to desire nothing but by the calendar. And that is not all: even if man really were nothing but a piano-key, even if this were proved to him by natural science and mathematics, even then he would not become reasonable, but would purposely do something perverse out of simple ingratitude, simply to gain his point. And if he does not find means he will contrive destruction and chaos, will contrive sufferings of all sorts, only to gain his point! He will launch a curse upon the world, and as only man can curse (it is his privilege, the primary distinction between him and other animals), may be by his curse alone he will attain his object--that is, convince himself that he is a man and not a piano-key!\n",
"'''\n",
"\n",
"inputs = lm.tokenizer([text], return_tensors=\"pt\")\n",
"output = lm(inputs)"
]
},
{
"cell_type": "markdown",
"id": "470b6fa8-1f45-42e6-afa8-ba6fbe576bb4",
"metadata": {},
"source": [
"the `output` variable now contains the result of BERT processing the passge of text. The property `output.hidden_states` contains the hidden states after each layer."
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "1c97a987-c239-4f51-8f39-f4f00dfa9464",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((768, 363), (768, 363), 6)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embed = output.hidden_states[0].detach().numpy()[0,:,:].T\n",
"hidden_state_layer = [layer.detach().numpy()[0,:,:].T for layer in output.hidden_states[1:]]\n",
"embed.shape, hidden_state_layer[0].shape, len(hidden_state_layer)"
]
},
{
"cell_type": "markdown",
"id": "41ee4085-a697-4753-98b8-f930b4abcad0",
"metadata": {},
"source": [
"`embed` now contains the embeddings of the inputs. Its dimensions are (embed_dim, number of tokens). \n",
"`hidden_state_layer` has the outputs of each of the model's 6 layers. The output of each layer is (embed_dim, number of tokens).\n",
"\n",
"This is how to calculate the cka similarity score between the embeddings layer and the output of the first layer:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "9fe8fa2b-39b7-4a04-bf4e-c62cffe3f2ff",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9042735809843326"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from ecco import analysis\n",
"analysis.cka(embed, hidden_state_layer[0])"
]
},
{
"cell_type": "markdown",
"id": "831da973-63bc-4a53-8022-8017ee82af57",
"metadata": {},
"source": [
"When we compare the embeddings with the output of the second layer, we see less similarity"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "d428f519-fb48-402d-8bea-962783fe36de",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.7774273750068427"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"analysis.cka(embed, hidden_state_layer[1])"
]
},
{
"cell_type": "markdown",
"id": "823e071c-ccc2-458e-b792-8d451bd48e41",
"metadata": {},
"source": [
"And so on"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "8621e099-a768-4e17-86e6-6281c4fe4a0f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.6922863613160068"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"analysis.cka(embed, hidden_state_layer[2])"
]
},
{
"cell_type": "markdown",
"id": "7c2de6f2-fe10-4672-ae51-cfdfe2382be2",
"metadata": {},
"source": [
"We can try with `cca`, `svcca` and `pwcca`. But we need to choose a subset of the neurons because these methods require more tokens than neurons (and advise 10x as many tokens as neurons to get a proper similarity score). \n",
"\n",
"Let's compare the similarities of the first 50 neurons."
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "fc3f50a6-eaa0-498e-896a-11d333d7fb5f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CCA - Embed vs. layer 0: 0.8518187635570224\n",
"CCA - Embed vs. layer 1: 0.7220358141619774\n"
]
}
],
"source": [
"print(\"CCA - Embed vs. layer 0:\", analysis.cca(embed[:50,:], hidden_state_layer[0][:50,:]))\n",
"print(\"CCA - Embed vs. layer 1:\", analysis.cca(embed[:50,:], hidden_state_layer[1][:50,:]))"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "ebb5d691-af03-4d0c-9282-c5338671df12",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SVCCA - Embed vs. layer 0: 0.7830642647708996\n",
"SVCCA - Embed vs. layer 1: 0.6833412957583129\n"
]
}
],
"source": [
"print(\"SVCCA - Embed vs. layer 0:\", analysis.svcca(embed[:50,:], hidden_state_layer[0][:50,:]))\n",
"print(\"SVCCA - Embed vs. layer 1:\", analysis.svcca(embed[:50,:], hidden_state_layer[1][:50,:]))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "5b3e47b3-e5a6-47ce-8964-29eb086e29ad",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PWCCA - Embed vs. layer 0: 0.8695735246868949\n",
"PWCCA - Embed vs. layer 1: 0.746195889153883\n"
]
}
],
"source": [
"print(\"PWCCA - Embed vs. layer 0:\", analysis.pwcca(embed[:50,:], hidden_state_layer[0][:50,:]))\n",
"print(\"PWCCA - Embed vs. layer 1:\", analysis.pwcca(embed[:50,:], hidden_state_layer[1][:50,:]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
6 changes: 4 additions & 2 deletions src/ecco/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
def cca(acts1, acts2):
"""
Calculate a similarity score for two activation matrices using Canonical Correlation Analysis (CCA). Returns the
average of all the correlation coefficients. Args: acts1: Activations matrix #1. 2D numPy array. Dimensions: (
neurons, token position) acts2: Activations matrix #2. 2D numPy array. Dimensions: (neurons, token position)
average of all the correlation coefficients.
Args:
acts1: Activations matrix #1. 2D numPy array. Dimensions: (neurons, token position)
acts2: Activations matrix #2. 2D numPy array. Dimensions: (neurons, token position)
Returns:
score: Float between 0 and 1, where 0 means not correlated, 1 means the two activation matrices are linear transformations of each other.
Expand Down
10 changes: 9 additions & 1 deletion src/ecco/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,15 @@ def __call__(self,

def _get_embeddings(self, input_ids):
"""
Takes the token ids of a sequence, returns a matrix of their embeddings.
Get token embeddings and one-hot vector into vocab. It's done via matrix multiplication
so that gradient attribution is available when needed.
Args:
input_ids: Int tensor containing token ids. Of length (sequence length).
Generally returned from the the tokenizer such as
lm.tokenizer(text, return_tensors="pt")['input_ids'][0]
Returns:
inputs_embeds: Embeddings of the tokens. Dimensions are (sequence_len, d_embed)
token_ids_tensor_one_hot: Dimensions are (sequence_len, vocab_size)
"""
# embedding_matrix = self.model.transformer.wte.weight
embedding_matrix = self.model_embeddings
Expand Down

0 comments on commit efc5950

Please sign in to comment.