diff --git a/textattack/goal_functions/goal_function.py b/textattack/goal_functions/goal_function.py index 16f49830..78693f67 100644 --- a/textattack/goal_functions/goal_function.py +++ b/textattack/goal_functions/goal_function.py @@ -176,13 +176,15 @@ def _call_model_uncached(self, attacked_text_list): if isinstance(batch_preds, list): outputs.extend(batch_preds) elif isinstance(batch_preds, np.ndarray): - outputs.append(torch.tensor(batch_preds)) + outputs.append(batch_preds) else: outputs.append(batch_preds) i += self.batch_size if isinstance(outputs[0], torch.Tensor): outputs = torch.cat(outputs, dim=0) + elif isinstance(outputs[0], np.ndarray): + outputs = np.concatenate(outputs).ravel() assert len(inputs) == len( outputs