From 97c3a172496e837fd4d2a1d5bf547c4173ba6c9c Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 10:45:27 -0800
Subject: [PATCH 01/26] set use_cache=False in forward pass since we dont use
 past

---
 src/ecco/lm.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 526ab37..d3b9e5d 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -97,7 +97,7 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,
         """
         inputs_embeds, token_ids_tensor_one_hot = self._get_embeddings(input_ids)

-        output = self.model(inputs_embeds=inputs_embeds, return_dict=True)
+        output = self.model(inputs_embeds=inputs_embeds, return_dict=True, use_cache=False)
         predict = output[0]
         past = output[1] # We're not using past because by presenting all the past tokens at every
         # step, we can get feature importance attribution. Let me know if it can be done with past
@@ -345,9 +345,9 @@ def display_token(self, viz_id, token_id, position):
             'type': 'output'
         }
         js = f"""
-        // We don't really need these require scripts. But this is to avert 
+        // We don't really need these require scripts. But this is to avert
         //this code from running before display_input_sequence which DOES require external files
-        requirejs(['basic', 'ecco'], function(basic, ecco){{ 
+        requirejs(['basic', 'ecco'], function(basic, ecco){{
            console.log('addToken viz_id', '{viz_id}');
           window.ecco['{viz_id}'].addToken({json.dumps(token)})
           window.ecco['{viz_id}'].redraw()
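[Editor's note] Patch 01 disables the key/value cache because ecco feeds the full token sequence through the model on every generation step (it needs gradients with respect to every input position for attribution), so the `past` the model would cache is never consumed. A minimal sketch of the behavior, assuming any Hugging Face causal LM; the model id and prompt here are illustrative, not part of the patch:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    ids = tokenizer("Hello world", return_tensors="pt").input_ids

    # Full-context forward pass; no key/value cache is built or returned.
    out = model(input_ids=ids, return_dict=True, use_cache=False)
    assert out.past_key_values is None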
From 9d461ac4c5c2ff069c191f170dd16f7f2c9b2ffc Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 11:08:24 -0800
Subject: [PATCH 02/26] access output entries by str key rather than int index

---
 src/ecco/lm.py | 26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index d3b9e5d..411b0f9 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -98,9 +98,7 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,
         inputs_embeds, token_ids_tensor_one_hot = self._get_embeddings(input_ids)

         output = self.model(inputs_embeds=inputs_embeds, return_dict=True, use_cache=False)
-        predict = output[0]
-        past = output[1] # We're not using past because by presenting all the past tokens at every
-        # step, we can get feature importance attribution. Let me know if it can be done with past
+        predict = output.logits

         scores = predict[-1, :]

@@ -125,7 +123,7 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,
             self.attributions['grad_x_input'] = []
         self.attributions['grad_x_input'].append(grad_x_input.cpu().detach().numpy())

-        return prediction_id, output, past
+        return prediction_id, output

     def generate(self, input_str: str, max_length: Optional[int] = 128,
                  temperature: Optional[float] = None,
@@ -163,13 +161,13 @@ def generate(self, input_str: str, max_length: Optional[int] = 128,
             viz_id = self.display_input_sequence(input_ids)

         while cur_len < max_length:
-            output_token_id, output, past = self._generate_token(input_ids,
-                                                                 past,
-                                                                 # Note, this is not currently used
-                                                                 temperature=temperature,
-                                                                 top_k=top_k, top_p=top_p,
-                                                                 do_sample=do_sample,
-                                                                 attribution_flag=attribution)
+            output_token_id, output = self._generate_token(input_ids,
+                                                           past,
+                                                           # Note, this is not currently used
+                                                           temperature=temperature,
+                                                           top_k=top_k, top_p=top_p,
+                                                           do_sample=do_sample,
+                                                           attribution_flag=attribution)

             if (get_model_output):
                 outputs.append(output)
@@ -189,16 +187,14 @@ def generate(self, input_str: str, max_length: Optional[int] = 128,
         if activations_dict != {}:
             self.activations = activations_dict_to_array(activations_dict)

-        hidden_states = output[2]
+        hidden_states = output.hidden_states
         tokens = []
         for i in input_ids:
             token = self.tokenizer.decode([i])
             tokens.append(token)

         attributions = self.attributions
-        attn = None
-        if len(output) == 4:
-            attn = output[-1]
+        attn = getattr(output, "attentions", None)

         return OutputSeq(**{'tokenizer': self.tokenizer,
                             'token_ids': input_ids,
                             'n_input_tokens': n_input_tokens,

From 71c724b759e27570a431ac615c6a471cb879fbad Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 11:52:49 -0800
Subject: [PATCH 03/26] free output.logits

---
 src/ecco/lm.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 411b0f9..d4dc470 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -123,6 +123,7 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,
             self.attributions['grad_x_input'] = []
         self.attributions['grad_x_input'].append(grad_x_input.cpu().detach().numpy())

+        del output.logits # free tensor memory we won't use again
         return prediction_id, output

     def generate(self, input_str: str, max_length: Optional[int] = 128,

From 924eed93dd26b4aa0e09b3605dde52542c8304aa Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 11:54:46 -0800
Subject: [PATCH 04/26] free grad memory for hidden states

---
 src/ecco/lm.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index d4dc470..d0daa04 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -124,6 +124,8 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,
         self.attributions['grad_x_input'].append(grad_x_input.cpu().detach().numpy())

         del output.logits # free tensor memory we won't use again
+        output.hidden_states = tuple([h.detach() for h in output.hidden_states]) # don't need grads here
+
         return prediction_id, output

     def generate(self, input_str: str, max_length: Optional[int] = 128,
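[Editor's note] Patches 03 and 04 are about letting PyTorch reclaim memory: as long as a Python reference to a graph-attached tensor survives, its slice of the autograd graph survives with it. A small illustration of the principle (standalone, with made-up tensors, not ecco's code):

    import torch

    x = torch.randn(4, 8, requires_grad=True)
    h = (x @ x.T).relu()   # h participates in an autograd graph

    # h.detach() returns a tensor that shares storage but holds no
    # reference to the graph, so the graph can be freed once no other
    # reference (e.g. a logits tensor) keeps it alive.
    h = h.detach()
    assert not h.requires_grad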
From 3842e23c385b2726c80150494a1bae32a4cda009 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 12:03:06 -0800
Subject: [PATCH 05/26] put hidden states on cpu as we generate

---
 src/ecco/lm.py     |  5 ++++-
 src/ecco/output.py | 12 ++++++------
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index d0daa04..83a207d 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -124,7 +124,10 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,
         self.attributions['grad_x_input'].append(grad_x_input.cpu().detach().numpy())

         del output.logits # free tensor memory we won't use again
-        output.hidden_states = tuple([h.detach() for h in output.hidden_states]) # don't need grads here
+
+        # detach(): don't need grads here
+        # cpu(): not used by GPU during generation; may lead to GPU OOM if left on GPU during long generations
+        output.hidden_states = tuple([h.cpu().detach() for h in output.hidden_states])

         return prediction_id, output

diff --git a/src/ecco/output.py b/src/ecco/output.py
index 909839f..17d2f00 100644
--- a/src/ecco/output.py
+++ b/src/ecco/output.py
@@ -173,10 +173,10 @@ def saliency(self, attr_method: Optional[str] = 'grad_x_input', style="minimal",
          data: {data},
          preset: 'viridis'
         }})
-
+
         window.ecco[viz_id].init();
         window.ecco[viz_id].selectFirstToken();
-
+
         }}, function (err) {{
             console.log(err);
         }})"""
@@ -252,7 +252,7 @@ def layer_predictions(self, position: int = 0, topk: Optional[int] = 10, layer:
             # print(h.shape)
             hidden_state = h[position - 1]
             # Use lm_head to project the layer's hidden state to output vocabulary
-            logits = self.lm_head(hidden_state)
+            logits = self.lm_head(self.to(hidden_state))
             softmax = F.softmax(logits, dim=-1)

             sorted_softmax = self.to(torch.argsort(softmax))
@@ -283,7 +283,7 @@ def layer_predictions(self, position: int = 0, topk: Optional[int] = 10, layer:
         js = f"""
         requirejs(['basic', 'ecco'], function(basic, ecco){{
             const viz_id = basic.init()
-
+
             let pred = new ecco.LayerPredictions({{

                 parentDiv: viz_id,
@@ -321,7 +321,7 @@ def rankings(self, **kwargs):
                 # print('hidden state layer', i, 'position', self.n_input_tokens-1+j)
                 # Project hidden state to vocabulary
                 # (after debugging pain: ensure input is on GPU, if appropriate)
-                logits = self.lm_head(hidden_state)
+                logits = self.lm_head(self.to(hidden_state))
                 # logits = self.lm_head(torch.tensor(hidden_state))
                 # Sort by score (ascending)
                 sorted = torch.argsort(logits)
@@ -380,7 +380,7 @@ def rankings_watch(self, watch: List[int] = None, position: int = -1, **kwargs):
             hidden_state = level[position]
             # Project hidden state to vocabulary
             # (after debugging pain: ensure input is on GPU, if appropriate)
-            logits = self.lm_head(hidden_state)
+            logits = self.lm_head(self.to(hidden_state))
             # logits = lmhead(torch.tensor(hidden_state))
             # Sort by score (ascending)
             sorted = torch.argsort(logits)

From 1a759fa2f4dfb178066297608082858ce4f499f1 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 13:17:30 -0800
Subject: [PATCH 06/26] fix edge case where squeeze would remove position dim
 if there is only one position

---
 src/ecco/lm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 83a207d..beaad46 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -43,7 +43,7 @@ def activations_dict_to_array(activations_dict):
     for i in range(len(activations_dict)):
         activations.append(activations_dict[i])

-    activations = np.squeeze(np.array(activations))
+    activations = np.concatenate(activations, axis=0)
     return np.swapaxes(activations, 1, 2)
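[Editor's note] The squeeze fix in patch 06 matters because np.squeeze removes every axis of length 1, not just the wrapper axis the code meant to drop; with a single generated position, the position axis is also length 1 and silently disappears. A small repro under the same data layout the code uses (per-layer lists of (positions, neurons) arrays; sizes illustrative):

    import numpy as np

    acts = [[np.zeros((1, 4))], [np.zeros((1, 4))]]  # two layers, ONE position

    np.squeeze(np.array(acts)).shape    # (2, 4): the position axis is gone
    np.concatenate(acts, axis=0).shape  # (2, 1, 4): position axis survives,
                                        # so swapaxes(1, 2) still works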
From 68d0fbcc98909ca544b9cf8cdf98132db7357fd7 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 13:23:13 -0800
Subject: [PATCH 07/26] retain position dim in scores arg for
 transformers.generation_utils.top_k_top_p_filtering (on transformers v3.4.0
 i could not get top_p to work without this change)

---
 src/ecco/lm.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index beaad46..0b33547 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -30,6 +30,7 @@ def sample_output_token(scores, do_sample, temperature, top_k, top_p):
     else:
         # Greedy decoding
         prediction_id = torch.argmax(scores, dim=-1)
+        prediction_id = prediction_id.squeeze()
     return prediction_id

@@ -100,7 +101,7 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,
         output = self.model(inputs_embeds=inputs_embeds, return_dict=True, use_cache=False)
         predict = output.logits

-        scores = predict[-1, :]
+        scores = predict[-1:, :]

         prediction_id = sample_output_token(scores, do_sample, temperature, top_k, top_p)
         # Print the sampled token

From 6f8d961b07401121ec1791f8cd0ba18f8090fde0 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 13:31:52 -0800
Subject: [PATCH 08/26] allow not returning hidden states

---
 src/ecco/__init__.py | 4 ++--
 src/ecco/lm.py       | 5 +++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/ecco/__init__.py b/src/ecco/__init__.py
index 046d1cd..e00e788 100644
--- a/src/ecco/__init__.py
+++ b/src/ecco/__init__.py
@@ -2,14 +2,14 @@ from ecco.lm import LM, MockGPT, MockGPTTokenizer
 from transformers import AutoTokenizer, AutoModelForCausalLM

-def from_pretrained(hf_model_id, activations=False, attention=False):
+def from_pretrained(hf_model_id, activations=False, attention=False, hidden_states=True):
     if hf_model_id == "mockGPT":
         tokenizer = MockGPTTokenizer()
         model = MockGPT()
     else:
         tokenizer = AutoTokenizer.from_pretrained(hf_model_id)
         model = AutoModelForCausalLM.from_pretrained(hf_model_id,
-                                                     output_hidden_states=True,
+                                                     output_hidden_states=hidden_states,
                                                      output_attentions=attention)
     if activations:
         lm = LM(model, tokenizer, collect_activations_flag=True)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 0b33547..08879bf 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -128,7 +128,8 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,

         # detach(): don't need grads here
         # cpu(): not used by GPU during generation; may lead to GPU OOM if left on GPU during long generations
-        output.hidden_states = tuple([h.cpu().detach() for h in output.hidden_states])
+        if hasattr(output, "hidden_states"):
+            output.hidden_states = tuple([h.cpu().detach() for h in output.hidden_states])

         return prediction_id, output
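[Editor's note] Patch 07 keeps `scores` two-dimensional because the transformers filtering helper operates row-wise on a (batch, vocab) tensor, and plain `[-1, :]` indexing drops the row axis; the added `.squeeze()` after argmax compensates so greedy decoding still yields a scalar id. Shape sketch (vocab size illustrative):

    import torch

    predict = torch.randn(5, 50257)   # (positions, vocab) logits

    predict[-1, :].shape    # torch.Size([50257])    1-D: row axis lost
    predict[-1:, :].shape   # torch.Size([1, 50257]) slice keeps the row axis,
                            # matching the (batch, vocab) input the filter expects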
From 3e437c15a91b43d024422d632bfc9a4fe907db2e Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 13:43:58 -0800
Subject: [PATCH 09/26] allow LM to collect activations for only a specified
 subset of layers

---
 src/ecco/lm.py | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 08879bf..1f7d852 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -55,7 +55,9 @@ class LM(object):

     def __init__(self, model, tokenizer,
                  collect_activations_flag=False,
-                 collect_gen_activations_flag=False):
+                 collect_gen_activations_flag=False,
+                 collect_activations_layer_nums=None,  # None --> collect for all layers
+                 ):
         self.model = model
         if torch.cuda.is_available():
             self.model = model.to('cuda')
@@ -69,6 +71,7 @@ def __init__(self, model, tokenizer,
         # Neuron Activation
         self.collect_activations_flag = collect_activations_flag
         self.collect_gen_activations_flag = collect_gen_activations_flag
+        self.collect_activations_layer_nums = collect_activations_layer_nums
         self._hooks = {}
         self._reset()
         self._attach_hooks(self.model)
@@ -260,13 +263,16 @@ def _get_activations_hook(self, name: str, input_):
         # Extract the number of the layer from the name
         layer_number = int(name.split('.')[2])

-        if layer_number not in self._all_activations_dict:
-            self._all_activations_dict[layer_number] = [0]
+        collecting_this_layer = (self.collect_activations_layer_nums is None) or (layer_number in self.collect_activations_layer_nums)

-        # Overwrite the previous step activations. This collects all activations in the last step
-        # Assuming all input tokens are presented as input, no "past"
-        # The inputs to c_proj already pass through the gelu activation function
-        self._all_activations_dict[layer_number][0] = input_[0][0].detach().cpu().numpy()
+        if collecting_this_layer:
+            if layer_number not in self._all_activations_dict:
+                self._all_activations_dict[layer_number] = [0]
+
+            # Overwrite the previous step activations. This collects all activations in the last step
+            # Assuming all input tokens are presented as input, no "past"
+            # The inputs to c_proj already pass through the gelu activation function
+            self._all_activations_dict[layer_number][0] = input_[0][0].detach().cpu().numpy()

     def _get_generation_activations_hook(self, name: str, input_):
         """
@@ -277,12 +283,15 @@ def _get_generation_activations_hook(self, name: str, input_):
         # Extract the number of the layer from the name
         layer_number = int(name.split('.')[2])

-        if layer_number not in self._generation_activations_dict:
-            self._generation_activations_dict[layer_number] = []
+        collecting_this_layer = (self.collect_activations_layer_nums is None) or (layer_number in self.collect_activations_layer_nums)
+
+        if collecting_this_layer:
+            if layer_number not in self._generation_activations_dict:
+                self._generation_activations_dict[layer_number] = []

-        # Accumulate in dict
-        # The inputs to c_proj already pass through the gelu activation function
-        self._generation_activations_dict[layer_number].append(input_[0][0][-1].detach().cpu().numpy())
+            # Accumulate in dict
+            # The inputs to c_proj already pass through the gelu activation function
+            self._generation_activations_dict[layer_number].append(input_[0][0][-1].detach().cpu().numpy())

     def _inhibit_neurons_hook(self, name: str, input_tensor):
         """

From 93135f3a8a7f43b8c4244ea8b47e4a08367bf7ea Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 13:55:01 -0800
Subject: [PATCH 10/26] bugfix for case with no hidden_states

---
 src/ecco/lm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 1f7d852..8cefe28 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -131,7 +131,7 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,

         # detach(): don't need grads here
         # cpu(): not used by GPU during generation; may lead to GPU OOM if left on GPU during long generations
-        if hasattr(output, "hidden_states"):
+        if getattr(output, "hidden_states", None) is not None:
             output.hidden_states = tuple([h.cpu().detach() for h in output.hidden_states])

         return prediction_id, output
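[Editor's note] Patch 09's hooks decide per call whether to record, by parsing the layer number out of the hooked module's name. A standalone restatement of that predicate (the module names below are hypothetical GPT-2-style names, shaped like the ones the hook actually sees):

    def should_collect(name, layer_nums):
        # e.g. 'transformer.h.3.mlp.c_proj' -> layer 3
        layer_number = int(name.split('.')[2])
        return layer_nums is None or layer_number in layer_nums

    should_collect('transformer.h.3.mlp.c_proj', None)     # True: collect all
    should_collect('transformer.h.3.mlp.c_proj', [0, 1])   # False: skipped
    should_collect('transformer.h.1.mlp.c_proj', [0, 1])   # True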
From d6fd1b5848eb74c664ae7c8403de3c850ac700f7 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 13:55:27 -0800
Subject: [PATCH 11/26] expose hidden_states and activations_layer_nums in
 from_pretrained construction route

---
 src/ecco/__init__.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/ecco/__init__.py b/src/ecco/__init__.py
index e00e788..bd5e497 100644
--- a/src/ecco/__init__.py
+++ b/src/ecco/__init__.py
@@ -2,7 +2,12 @@ from ecco.lm import LM, MockGPT, MockGPTTokenizer
 from transformers import AutoTokenizer, AutoModelForCausalLM

-def from_pretrained(hf_model_id, activations=False, attention=False, hidden_states=True):
+def from_pretrained(hf_model_id,
+                    activations=False,
+                    attention=False,
+                    hidden_states=True,
+                    activations_layer_nums=None,
+                    ):
     if hf_model_id == "mockGPT":
         tokenizer = MockGPTTokenizer()
         model = MockGPT()
@@ -11,9 +16,9 @@ def from_pretrained(hf_model_id, activations=False, attention=False, hidden_stat
         model = AutoModelForCausalLM.from_pretrained(hf_model_id,
                                                      output_hidden_states=hidden_states,
                                                      output_attentions=attention)
-    if activations:
-        lm = LM(model, tokenizer, collect_activations_flag=True)
-        return lm
-    else:
-        lm = LM(model, tokenizer)
-        return lm
+
+    lm_kwargs = {
+        'collect_activations_flag': activations,
+        'collect_activations_layer_nums': activations_layer_nums}
+    lm = LM(model, tokenizer, **lm_kwargs)
+    return lm
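[Editor's note] After patch 11 the construction options flow through a single LM(**lm_kwargs) call instead of an if/else. A usage sketch implied by the new signature (model id illustrative; per patch 09, `activations_layer_nums` only has an effect when `activations=True`):

    import ecco

    # Record MLP activations for layers 0 and 5 only.
    lm = ecco.from_pretrained('gpt2',
                              activations=True,
                              activations_layer_nums=[0, 5])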
From f5bb72024914a99c3e729dbcc27930a37ccacdfc Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 14:01:29 -0800
Subject: [PATCH 12/26] undo some spurious whitespace changes my editor made
 automatically

---
 src/ecco/lm.py     | 4 ++--
 src/ecco/output.py | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 8cefe28..a053ce1 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -358,9 +358,9 @@ def display_token(self, viz_id, token_id, position):
             'type': 'output'
         }
         js = f"""
-        // We don't really need these require scripts. But this is to avert
+        // We don't really need these require scripts. But this is to avert 
         //this code from running before display_input_sequence which DOES require external files
-        requirejs(['basic', 'ecco'], function(basic, ecco){{
+        requirejs(['basic', 'ecco'], function(basic, ecco){{ 
            console.log('addToken viz_id', '{viz_id}');
           window.ecco['{viz_id}'].addToken({json.dumps(token)})
           window.ecco['{viz_id}'].redraw()
diff --git a/src/ecco/output.py b/src/ecco/output.py
index 17d2f00..57ad70b 100644
--- a/src/ecco/output.py
+++ b/src/ecco/output.py
@@ -173,10 +173,10 @@ def saliency(self, attr_method: Optional[str] = 'grad_x_input', style="minimal",
          data: {data},
          preset: 'viridis'
         }})
-
+
         window.ecco[viz_id].init();
         window.ecco[viz_id].selectFirstToken();
-
+
         }}, function (err) {{
             console.log(err);
         }})"""
@@ -283,7 +283,7 @@ def layer_predictions(self, position: int = 0, topk: Optional[int] = 10, layer:
         js = f"""
         requirejs(['basic', 'ecco'], function(basic, ecco){{
             const viz_id = basic.init()
-
+
             let pred = new ecco.LayerPredictions({{

                 parentDiv: viz_id,
From 14f3b0054ca5c8f635acc1afa5938fcde5501189 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 14:14:12 -0800
Subject: [PATCH 13/26] fix layer indexing in NMF when using
 collect_activations_layer_nums to subset layers

---
 src/ecco/lm.py     |  1 +
 src/ecco/output.py | 24 +++++++++++++++++++-----
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index a053ce1..4673bf1 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -216,6 +216,7 @@ def generate(self, input_str: str, max_length: Optional[int] = 128,
                             'model_outputs': outputs,
                             'attribution': attributions,
                             'activations': self.activations,
+                            'collect_activations_layer_nums': self.collect_activations_layer_nums,
                             'lm_head': self.model.lm_head,
                             'device': self.device})

diff --git a/src/ecco/output.py b/src/ecco/output.py
index 57ad70b..de393bf 100644
--- a/src/ecco/output.py
+++ b/src/ecco/output.py
@@ -23,6 +23,7 @@ def __init__(self,
                  attribution=None,
                  activations=None,
                  activations_type=None,
+                 collect_activations_layer_nums=None,
                  attention=None,
                  model_outputs=None,
                  lm_head=None,
@@ -36,6 +37,7 @@ def __init__(self,
         self.attribution = attribution
         self.activations = activations
         self.activations_type = activations_type
+        self.collect_activations_layer_nums = collect_activations_layer_nums
         self.model_outputs = model_outputs
         self.attention_values = attention
         self.lm_head = lm_head
@@ -173,10 +175,10 @@ def saliency(self, attr_method: Optional[str] = 'grad_x_input', style="minimal",
          data: {data},
          preset: 'viridis'
         }})
-
+
         window.ecco[viz_id].init();
         window.ecco[viz_id].selectFirstToken();
-
+
         }}, function (err) {{
             console.log(err);
         }})"""
@@ -283,7 +285,7 @@ def layer_predictions(self, position: int = 0, topk: Optional[int] = 10, layer:
         js = f"""
         requirejs(['basic', 'ecco'], function(basic, ecco){{
             const viz_id = basic.init()
-
+
             let pred = new ecco.LayerPredictions({{

                 parentDiv: viz_id,
@@ -424,7 +426,9 @@ def run_nmf(self, **kwargs):
                    n_input_tokens=self.n_input_tokens,
                    token_ids=self.token_ids,
                    _path=self._path,
-                   tokens=self.tokens, **kwargs)
+                   tokens=self.tokens,
+                   collect_activations_layer_nums=self.collect_activations_layer_nums,
+                   **kwargs)

     def attention(self, attention_values=None, layer=0, **kwargs):

@@ -490,6 +494,7 @@ def __init__(self, activations: np.ndarray,
                  # from_layer: Optional[int] = None,
                  # to_layer: Optional[int] = None,
                  tokens: Optional[List[str]] = None,
+                 collect_activations_layer_nums: Optional[List[int]]=None,
                  **kwargs):
         self._path = _path
         self.token_ids = token_ids
@@ -497,6 +502,13 @@ def __init__(self, activations: np.ndarray,
         from_layer = kwargs['from_layer'] if 'from_layer' in kwargs else None
         to_layer = kwargs['to_layer'] if 'to_layer' in kwargs else None
+
+        if collect_activations_layer_nums is None:
+            collect_activations_layer_nums = list(range(activations.shape[0]))
+
+        layer_nums_to_row_ixs = {layer_num: i
+                                 for i, layer_num in enumerate(collect_activations_layer_nums)}
+
         if len(activations.shape) != 3:
             raise ValueError(f"The 'activations' parameter should have three dimensions: (layers, neurons, positions). "
                              f"Supplied dimensions: {activations.shape}", 'activations')
@@ -515,7 +527,9 @@ def __init__(self, activations: np.ndarray,
             from_layer = 0
             to_layer = activations.shape[0]

-        merged_act = np.concatenate(activations[from_layer: to_layer], axis=0)
+        row_ixs = [layer_nums_to_row_ixs[layer_num] for layer_num in range(from_layer, to_layer)]
+        activation_rows = [activations[row_ix] for row_ix in row_ixs]
+        merged_act = np.concatenate(activation_rows, axis=0)
         activations = np.expand_dims(merged_act, axis=0)

         self.tokens = tokens

From da556a598918307401cb41f4df211c483021b67a Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 14:22:41 -0800
Subject: [PATCH 14/26] clearer message on missing layer activation

---
 src/ecco/output.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/ecco/output.py b/src/ecco/output.py
index de393bf..cbcae59 100644
--- a/src/ecco/output.py
+++ b/src/ecco/output.py
@@ -527,7 +527,11 @@ def __init__(self, activations: np.ndarray,
             from_layer = 0
             to_layer = activations.shape[0]

-        row_ixs = [layer_nums_to_row_ixs[layer_num] for layer_num in range(from_layer, to_layer)]
+        layer_nums = list(range(from_layer, to_layer))
+        if any([num not in layer_nums_to_row_ixs for num in layer_nums]):
+            available = sorted(layer_nums_to_row_ixs.keys())
+            raise ValueError(f"Not all layers between from_layer ({from_layer}) and to_layer ({to_layer}) have recorded activations. Layers with recorded activations are: {available}")
+        row_ixs = [layer_nums_to_row_ixs[layer_num] for layer_num in layer_nums]
         activation_rows = [activations[row_ix] for row_ix in row_ixs]
         merged_act = np.concatenate(activation_rows, axis=0)
         activations = np.expand_dims(merged_act, axis=0)
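[Editor's note] The indexing fix in patch 13: once only a subset of layers is recorded, row i of the activations array corresponds to layer collect_activations_layer_nums[i], so NMF must translate layer numbers to row indices rather than slice directly. Patch 14 then makes the failure explicit when a requested layer was never recorded. The mapping in miniature (layer choices illustrative):

    collected = [2, 5]                                  # layers recorded
    layer_nums_to_row_ixs = {n: i for i, n in enumerate(collected)}
    # {2: 0, 5: 1}: row index != layer number

    wanted = range(2, 6)                                # from_layer=2, to_layer=6
    missing = [n for n in wanted if n not in layer_nums_to_row_ixs]
    # [3, 4] -> better to raise and list the available layers than to
    # silently grab rows that belong to different layers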
From c057447f56a2fd6a62fb7c9ce0ee190d4c47bedc Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 14:24:41 -0800
Subject: [PATCH 15/26] bugfix

---
 src/ecco/lm.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 4673bf1..0380620 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -41,7 +41,7 @@ def _one_hot(token_ids, vocab_size):
 def activations_dict_to_array(activations_dict):
     # print(activations_dict[0].shape)
     activations = []
-    for i in range(len(activations_dict)):
+    for i in sorted(activations_dict.keys()):
         activations.append(activations_dict[i])

     activations = np.concatenate(activations, axis=0)
@@ -359,9 +359,9 @@ def display_token(self, viz_id, token_id, position):
             'type': 'output'
         }
         js = f"""
-        // We don't really need these require scripts. But this is to avert 
+        // We don't really need these require scripts. But this is to avert
        //this code from running before display_input_sequence which DOES require external files
-        requirejs(['basic', 'ecco'], function(basic, ecco){{ 
+        requirejs(['basic', 'ecco'], function(basic, ecco){{
            console.log('addToken viz_id', '{viz_id}');
           window.ecco['{viz_id}'].addToken({json.dumps(token)})
           window.ecco['{viz_id}'].redraw()

From 63877c8e570e6865f213dcf5b7a4425a5ff92b56 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 14:33:09 -0800
Subject: [PATCH 16/26] reduce memory leakage (?) in saliency by avoiding
 retain_graph

---
 src/ecco/attribution.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/src/ecco/attribution.py b/src/ecco/attribution.py
index 64bf0d8..f36dd7f 100644
--- a/src/ecco/attribution.py
+++ b/src/ecco/attribution.py
@@ -4,19 +4,19 @@

 def saliency(prediction_logit, token_ids_tensor_one_hot, norm=True):
     # Back-propegate the gradient from the selected output-logit
-    prediction_logit.backward(retain_graph=True)
+    grad = torch.autograd.grad(prediction_logit, token_ids_tensor_one_hot)[0]

     # token_ids_tensor_one_hot.grad is the gradient propegated to ever embedding dimension of
     # the input tokens.
     if norm:  # norm calculates a scalar value (L2 Norm)
-        token_importance_raw = torch.norm(token_ids_tensor_one_hot.grad, dim=1)
+        token_importance_raw = torch.norm(grad, dim=1)
         # print('token_importance_raw', token_ids_tensor_one_hot.grad.shape,
         # np.count_nonzero(token_ids_tensor_one_hot.detach().numpy(), axis=1))

         # Normalize the values so they add up to 1
         token_importance = token_importance_raw / torch.sum(token_importance_raw)
     else:
-        token_importance = torch.sum(token_ids_tensor_one_hot.grad, dim=1) # Only one value, all others are zero
+        token_importance = torch.sum(grad, dim=1) # Only one value, all others are zero

     token_ids_tensor_one_hot.grad.data.zero_()
     return token_importance
@@ -26,7 +26,6 @@ def saliency_on_d_embeddings(prediction_logit, inputs_embeds, aggregation="L2"):
     inputs_embeds.retain_grad()

     # Back-propegate the gradient from the selected output-logit
-    prediction_logit.backward(retain_graph=True)

     # inputs_embeds.grad
     # token_ids_tensor_one_hot.grad is the gradient propegated to ever embedding dimension of
@@ -50,13 +49,9 @@

 def gradient_x_inputs_attribution(prediction_logit, inputs_embeds):
-    inputs_embeds.retain_grad()
     # back-prop gradient
-    prediction_logit.backward(retain_graph=True)
-    grad = inputs_embeds.grad
-    # This should be equivalent to
-    # grad = torch.autograd.grad(prediction_logit, inputs_embeds)[0]
+    grad = torch.autograd.grad(prediction_logit, inputs_embeds)[0]

     # Grad X Input
     grad_x_input = grad * inputs_embeds
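[Editor's note] The patch 15 bugfix follows from patch 09: once layers can be skipped, the activation dict's keys are arbitrary layer numbers, so `range(len(d))` indexes keys that may not exist. In miniature:

    acts = {2: ['layer 2 data'], 5: ['layer 5 data']}

    # Old: [acts[i] for i in range(len(acts))]  -> KeyError: 0
    [acts[i] for i in sorted(acts.keys())]
    # [['layer 2 data'], ['layer 5 data']]: the keys that exist, in order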
From 9f5364d08a9d316639ce0f14e4f87735b352b26b Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 14:34:10 -0800
Subject: [PATCH 17/26] fix

---
 src/ecco/attribution.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/ecco/attribution.py b/src/ecco/attribution.py
index f36dd7f..2e4282f 100644
--- a/src/ecco/attribution.py
+++ b/src/ecco/attribution.py
@@ -44,7 +44,6 @@ def saliency_on_d_embeddings(prediction_logit, inputs_embeds, aggregation="L2"):
         token_importance_raw = torch.mean(inputs_embeds.grad, dim=1)
         token_importance = token_importance_raw # Hmmm, how to normalize if it includes negative values

-    inputs_embeds.grad.data.zero_()
     return token_importance
@@ -62,7 +61,4 @@ def gradient_x_inputs_attribution(prediction_logit, inputs_embeds):
     # Normalize so we can show scores as percentages
     token_importance_normalized = feature_importance / torch.sum(feature_importance)

-    # Zero the gradient for the tensor so next backward() calls don't have
-    # gradients accumulating
-    inputs_embeds.grad.data.zero_()
     return token_importance_normalized

From 1c2b5e60d73a5279c8e30dc92b522fe897419b54 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 14:44:50 -0800
Subject: [PATCH 18/26] fix

---
 src/ecco/attribution.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/ecco/attribution.py b/src/ecco/attribution.py
index 2e4282f..71989a9 100644
--- a/src/ecco/attribution.py
+++ b/src/ecco/attribution.py
@@ -18,7 +18,6 @@ def saliency(prediction_logit, token_ids_tensor_one_hot, norm=True):
     else:
         token_importance = torch.sum(grad, dim=1) # Only one value, all others are zero

-    token_ids_tensor_one_hot.grad.data.zero_()
     return token_importance
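[Editor's note] Patches 17 and 18 clean up after patch 16's switch to torch.autograd.grad: unlike backward(), autograd.grad returns the gradient as a value without populating the input tensor's .grad attribute, so the old `.grad.data.zero_()` bookkeeping now dereferences None. Minimal demonstration:

    import torch

    x = torch.randn(3, requires_grad=True)
    y = (x * 2).sum()

    g = torch.autograd.grad(y, x)[0]  # gradient comes back as a value...
    print(g)                          # tensor([2., 2., 2.])
    print(x.grad)                     # None: .grad was never populated, so
                                      # x.grad.data.zero_() would raise here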
From d121a69795fa1f3c8cb58db59ca3965be538a855 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 14:50:31 -0800
Subject: [PATCH 19/26] undo changes that did not work as i expected

---
 src/ecco/attribution.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/ecco/attribution.py b/src/ecco/attribution.py
index 71989a9..64bf0d8 100644
--- a/src/ecco/attribution.py
+++ b/src/ecco/attribution.py
@@ -4,20 +4,21 @@

 def saliency(prediction_logit, token_ids_tensor_one_hot, norm=True):
     # Back-propegate the gradient from the selected output-logit
-    grad = torch.autograd.grad(prediction_logit, token_ids_tensor_one_hot)[0]
+    prediction_logit.backward(retain_graph=True)

     # token_ids_tensor_one_hot.grad is the gradient propegated to ever embedding dimension of
     # the input tokens.
     if norm:  # norm calculates a scalar value (L2 Norm)
-        token_importance_raw = torch.norm(grad, dim=1)
+        token_importance_raw = torch.norm(token_ids_tensor_one_hot.grad, dim=1)
         # print('token_importance_raw', token_ids_tensor_one_hot.grad.shape,
         # np.count_nonzero(token_ids_tensor_one_hot.detach().numpy(), axis=1))

         # Normalize the values so they add up to 1
         token_importance = token_importance_raw / torch.sum(token_importance_raw)
     else:
-        token_importance = torch.sum(grad, dim=1) # Only one value, all others are zero
+        token_importance = torch.sum(token_ids_tensor_one_hot.grad, dim=1) # Only one value, all others are zero

+    token_ids_tensor_one_hot.grad.data.zero_()
     return token_importance
@@ -26,6 +27,7 @@ def saliency_on_d_embeddings(prediction_logit, inputs_embeds, aggregation="L2"):
     inputs_embeds.retain_grad()

     # Back-propegate the gradient from the selected output-logit
+    prediction_logit.backward(retain_graph=True)

     # inputs_embeds.grad
     # token_ids_tensor_one_hot.grad is the gradient propegated to ever embedding dimension of
@@ -43,13 +45,18 @@ def saliency_on_d_embeddings(prediction_logit, inputs_embeds, aggregation="L2"):
         token_importance_raw = torch.mean(inputs_embeds.grad, dim=1)
         token_importance = token_importance_raw # Hmmm, how to normalize if it includes negative values

+    inputs_embeds.grad.data.zero_()
     return token_importance


 def gradient_x_inputs_attribution(prediction_logit, inputs_embeds):
+    inputs_embeds.retain_grad()
     # back-prop gradient
-    grad = torch.autograd.grad(prediction_logit, inputs_embeds)[0]
+    prediction_logit.backward(retain_graph=True)
+    grad = inputs_embeds.grad
+    # This should be equivalent to
+    # grad = torch.autograd.grad(prediction_logit, inputs_embeds)[0]

     # Grad X Input
     grad_x_input = grad * inputs_embeds
@@ -60,4 +67,7 @@ def gradient_x_inputs_attribution(prediction_logit, inputs_embeds):
     # Normalize so we can show scores as percentages
     token_importance_normalized = feature_importance / torch.sum(feature_importance)

+    # Zero the gradient for the tensor so next backward() calls don't have
+    # gradients accumulating
+    inputs_embeds.grad.data.zero_()
     return token_importance_normalized
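[Editor's note] Patch 19 reverts to the backward(retain_graph=True) version. The commit messages don't record exactly what failed, but note the structural constraint the original code satisfies: several attribution functions may backpropagate from the same logit within one generation step, and every backward pass before the last must keep the graph alive. A sketch of that invariant (standalone, not ecco's code):

    import torch

    x = torch.randn(3, requires_grad=True)
    y = (x ** 2).sum()

    y.backward(retain_graph=True)   # earlier pass: keep the graph
    y.backward()                    # final pass: graph is freed here
    # A third y.backward() would now raise a RuntimeError.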
From 5e7de0afdbf5c77082717e40d4a961fb6e0444ac Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 14:56:15 -0800
Subject: [PATCH 20/26] 2nd attempt at not persisting backward graph in
 saliency

---
 src/ecco/attribution.py | 30 ++++++++++++++++++++++------
 src/ecco/lm.py          |  9 ++++-----
 2 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/src/ecco/attribution.py b/src/ecco/attribution.py
index 64bf0d8..7efaa6c 100644
--- a/src/ecco/attribution.py
+++ b/src/ecco/attribution.py
@@ -2,9 +2,9 @@
 import numpy as np


-def saliency(prediction_logit, token_ids_tensor_one_hot, norm=True):
+def saliency(prediction_logit, token_ids_tensor_one_hot, norm=True, retain_graph=True):
     # Back-propegate the gradient from the selected output-logit
-    prediction_logit.backward(retain_graph=True)
+    prediction_logit.backward(retain_graph=retain_graph)

     # token_ids_tensor_one_hot.grad is the gradient propegated to ever embedding dimension of
     # the input tokens.
@@ -22,11 +22,11 @@ def saliency(prediction_logit, token_ids_tensor_one_hot, norm=True):
     return token_importance


-def saliency_on_d_embeddings(prediction_logit, inputs_embeds, aggregation="L2"):
+def saliency_on_d_embeddings(prediction_logit, inputs_embeds, aggregation="L2", retain_graph=True):
     inputs_embeds.retain_grad()

     # Back-propegate the gradient from the selected output-logit
-    prediction_logit.backward(retain_graph=True)
+    prediction_logit.backward(retain_graph=retain_graph)

     # inputs_embeds.grad
     # token_ids_tensor_one_hot.grad is the gradient propegated to ever embedding dimension of
@@ -49,11 +49,11 @@ def saliency_on_d_embeddings(prediction_logit, inputs_embeds, aggregation="L2"):
     return token_importance


-def gradient_x_inputs_attribution(prediction_logit, inputs_embeds):
+def gradient_x_inputs_attribution(prediction_logit, inputs_embeds, retain_graph=True):
     inputs_embeds.retain_grad()
     # back-prop gradient
-    prediction_logit.backward(retain_graph=True)
+    prediction_logit.backward(retain_graph=retain_graph)
     grad = inputs_embeds.grad
     # This should be equivalent to
     # grad = torch.autograd.grad(prediction_logit, inputs_embeds)[0]
@@ -71,3 +71,21 @@ def gradient_x_inputs_attribution(prediction_logit, inputs_embeds):
     # gradients accumulating
     inputs_embeds.grad.data.zero_()
     return token_importance_normalized
+
+def compute_saliency_scores(prediction_logit,
+                            token_ids_tensor_one_hot,
+                            inputs_embeds,
+                            gradient_kwargs={},
+                            gradient_x_input_kwargs={},
+                            ):
+    results = {}
+
+    results['gradient'] = saliency(prediction_logit,
+                                   token_ids_tensor_one_hot,
+                                   retain_graph=True,
+                                   **gradient_kwargs)
+    results['grad_x_input'] = gradient_x_inputs_attribution(prediction_logit,
+                                                            inputs_embeds,
+                                                            retain_graph=False,
+                                                            **gradient_x_input_kwargs)
+    return results
diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 0380620..5de82ce 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -116,16 +116,15 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,
         prediction_logit = predict[inputs_embeds.shape[0] - 1][prediction_id]

         if attribution_flag:
-            saliency_scores = saliency(prediction_logit, token_ids_tensor_one_hot)
+            saliency_results = compute_saliency_scores(prediction_logit, token_ids_tensor_one_hot, inputs_embeds)
+
             if 'gradient' not in self.attributions:
                 self.attributions['gradient'] = []
-            self.attributions['gradient'].append(saliency_scores.cpu().detach().numpy())
+            self.attributions['gradient'].append(saliency_results['gradient'].cpu().detach().numpy())

-            grad_x_input = gradient_x_inputs_attribution(prediction_logit,
-                                                         inputs_embeds)
             if 'grad_x_input' not in self.attributions:
                 self.attributions['grad_x_input'] = []
-            self.attributions['grad_x_input'].append(grad_x_input.cpu().detach().numpy())
+            self.attributions['grad_x_input'].append(saliency_results['grad_x_input'].cpu().detach().numpy())

         del output.logits # free tensor memory we won't use again

From d3a09b7625278e277161f8004449369db03f7617 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 15:04:59 -0800
Subject: [PATCH 21/26] try reversing order

---
 src/ecco/attribution.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/ecco/attribution.py b/src/ecco/attribution.py
index 7efaa6c..63230de 100644
--- a/src/ecco/attribution.py
+++ b/src/ecco/attribution.py
@@ -80,12 +80,14 @@ def compute_saliency_scores(prediction_logit,
                             ):
     results = {}

-    results['gradient'] = saliency(prediction_logit,
-                                   token_ids_tensor_one_hot,
-                                   retain_graph=True,
-                                   **gradient_kwargs)
     results['grad_x_input'] = gradient_x_inputs_attribution(prediction_logit,
                                                             inputs_embeds,
-                                                            retain_graph=False,
+                                                            retain_graph=True,
                                                             **gradient_x_input_kwargs)
+
+    results['gradient'] = saliency(prediction_logit,
+                                   token_ids_tensor_one_hot,
+                                   retain_graph=False,
+                                   **gradient_kwargs)
+
     return results
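[Editor's note] Patches 20 and 21 package the two attribution calls so that exactly one of them, the last to run, releases the graph with retain_graph=False; patch 21 just swaps which call goes last. The pattern in isolation (standalone tensors, not ecco's code):

    import torch

    x = torch.randn(3, requires_grad=True)
    y = (x ** 2).sum()

    # First consumer of the graph must retain it...
    grad_a = torch.autograd.grad(y, x, retain_graph=True)[0]
    # ...the last consumer may let it be freed.
    y.backward(retain_graph=False)
    print(grad_a, x.grad)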
From 1e957a4c1c9bd49c203993a23ada18b7f2cc0869 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 15:33:14 -0800
Subject: [PATCH 22/26] if not passing from_layer or to_layer, default to
 subset available, not all layers

---
 src/ecco/output.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/ecco/output.py b/src/ecco/output.py
index cbcae59..58c3012 100644
--- a/src/ecco/output.py
+++ b/src/ecco/output.py
@@ -523,11 +523,11 @@ def __init__(self, activations: np.ndarray,

             if from_layer > to_layer:
                 raise ValueError(f"from_layer ({from_layer}) cannot be larger than to_layer ({to_layer}).")
+
+            layer_nums = list(range(from_layer, to_layer))
         else:
-            from_layer = 0
-            to_layer = activations.shape[0]
+            layer_nums = sorted(layer_nums_to_row_ixs.keys())

-        layer_nums = list(range(from_layer, to_layer))
         if any([num not in layer_nums_to_row_ixs for num in layer_nums]):
             available = sorted(layer_nums_to_row_ixs.keys())
             raise ValueError(f"Not all layers between from_layer ({from_layer}) and to_layer ({to_layer}) have recorded activations. Layers with recorded activations are: {available}")

From d58cb46103a351ba54e79361fdb743bbd1bb0e63 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 16:36:48 -0800
Subject: [PATCH 23/26] clear actual value, not just convenience name exposed
 for it

---
 src/ecco/lm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 5de82ce..5612a6a 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -126,7 +126,7 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,
             self.attributions['grad_x_input'] = []
             self.attributions['grad_x_input'].append(saliency_results['grad_x_input'].cpu().detach().numpy())

-        del output.logits # free tensor memory we won't use again
+        del output['logits'] # free tensor memory we won't use again

         # detach(): don't need grads here
         # cpu(): not used by GPU during generation; may lead to GPU OOM if left on GPU during long generations

From 72eb173c317647ad444b59f8d2b2d87a3d660726 Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Wed, 23 Dec 2020 16:40:10 -0800
Subject: [PATCH 24/26] (2nd try) clear actual value, not just convenience
 name exposed for it

---
 src/ecco/lm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ecco/lm.py b/src/ecco/lm.py
index 5612a6a..ff0162d 100644
--- a/src/ecco/lm.py
+++ b/src/ecco/lm.py
@@ -126,7 +126,7 @@ def _generate_token(self, input_ids, past, do_sample: bool, temperature: float,
             self.attributions['grad_x_input'] = []
             self.attributions['grad_x_input'].append(saliency_results['grad_x_input'].cpu().detach().numpy())

-        del output['logits'] # free tensor memory we won't use again
+        output['logits'] = None # free tensor memory we won't use again

         # detach(): don't need grads here
         # cpu(): not used by GPU during generation; may lead to GPU OOM if left on GPU during long generations

From d4f57969404e9f04d4f504854403051df352e7cd Mon Sep 17 00:00:00 2001
From: Jay Alammar
Date: Thu, 24 Dec 2020 15:35:58 +0300
Subject: [PATCH 25/26] New branch for 0.0.11

---
 setup.py             | 2 +-
 src/ecco/__init__.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index f3e80e3..e46dfc2 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@ def read(*names, **kwargs):

 setup(
     name='ecco',
-    version='0.0.10',
+    version='0.0.11',
     license='BSD-3-Clause',
     description='Visualization tools for NLP machine learning models.',
     long_description='%s\n%s' % (
diff --git a/src/ecco/__init__.py b/src/ecco/__init__.py
index 046d1cd..a4e4ef5 100644
--- a/src/ecco/__init__.py
+++ b/src/ecco/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '0.0.10'
+__version__ = '0.0.11'

 from ecco.lm import LM, MockGPT, MockGPTTokenizer
 from transformers import AutoTokenizer, AutoModelForCausalLM

From 90deb5caf1502401e3f8a1848b60f357618d7c9d Mon Sep 17 00:00:00 2001
From: nostalgebraist
Date: Thu, 24 Dec 2020 10:51:58 -0800
Subject: [PATCH 26/26] correct shape in test_activations_dict_to_array input
 example

---
 tests/lm_test.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/lm_test.py b/tests/lm_test.py
index 4cab263..ec04bbe 100644
--- a/tests/lm_test.py
+++ b/tests/lm_test.py
@@ -27,8 +27,8 @@ def test_select_output_token_sample(self):
         assert result == torch.tensor(2)

     def test_activations_dict_to_array(self):
-        dict = {0:[[np.zeros((3,4))]],
-                1:[[np.zeros((3,4))]]}
+        dict = {0:[np.zeros((3,4))],
+                1:[np.zeros((3,4))]}

         activations = activations_dict_to_array(dict)
         assert activations.shape == (2,4,3)
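[Editor's note] The corrected test fixture in patch 26 mirrors what the collection hook actually stores after patches 09 and 15: each dict value is a list holding one (positions, neurons) array, and activations_dict_to_array concatenates and then swaps axes to get (layers, neurons, positions). Tracing the shapes:

    import numpy as np

    acts = {0: [np.zeros((3, 4))],   # layer 0: 3 positions x 4 neurons
            1: [np.zeros((3, 4))]}   # layer 1

    arrays = [acts[i] for i in sorted(acts.keys())]  # [[arr], [arr]]
    stacked = np.concatenate(arrays, axis=0)         # (2, 3, 4)
    final = np.swapaxes(stacked, 1, 2)               # (2, 4, 3)
    assert final.shape == (2, 4, 3)                  # layers, neurons, positions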