Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hard label classification #635

Merged
merged 2 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions textattack/attack_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
#
# Classification goal functions
#
"hardlabel-classification": "textattack.goal_functions.classification.HardLabelClassification",
"targeted-classification": "textattack.goal_functions.classification.TargetedClassification",
"untargeted-classification": "textattack.goal_functions.classification.UntargetedClassification",
"input-reduction": "textattack.goal_functions.classification.InputReduction",
Expand All @@ -126,15 +127,13 @@
@dataclass
class AttackArgs:
"""Attack arguments to be passed to :class:`~textattack.Attacker`.

Args:
num_examples (:obj:`int`, 'optional`, defaults to :obj:`10`):
The number of examples to attack. :obj:`-1` for entire dataset.
num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
The number of successful adversarial examples we want. This is different from :obj:`num_examples`
as :obj:`num_examples` only cares about attacking `N` samples while :obj:`num_successful_examples` aims to keep attacking
until we have `N` successful cases.

.. note::
If set, this argument overrides `num_examples` argument.
num_examples_offset (:obj: `int`, `optional`, defaults to :obj:`0`):
Expand All @@ -148,7 +147,6 @@ class AttackArgs:
query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
The maximum number of model queries allowed per example attacked.
If not set, we use the query budget set in the :class:`~textattack.goal_functions.GoalFunction` object (which by default is :obj:`float("inf")`).

.. note::
Setting this overwrites the query budget set in :class:`~textattack.goal_functions.GoalFunction` object.
checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
Expand Down Expand Up @@ -439,7 +437,6 @@ def create_loggers_from_args(cls, args):
class _CommandLineAttackArgs:
"""Attack args for command line execution. This requires more arguments to
create ``Attack`` object as specified.

Args:
transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
Name of transformation to use.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Determine if an attack has been successful in Hard Label Classficiation.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: Classification

----------------------------------------------------
"""


from .classification_goal_function import ClassificationGoalFunction


class HardLabelClassification(ClassificationGoalFunction):
    """A hard-label attack goal on classification models: the attack
    succeeds once the model's predicted class for the perturbed input
    differs from the ground-truth class (i.e. the input has been pushed
    across the decision boundary).

    Args:
        target_max_score (float, optional): If set, the goal is instead to
            reduce the model's score for the ground-truth class below this
            value. If ``None`` (default), the goal is to change the overall
            predicted class.
    """

    def __init__(self, *args, target_max_score=None, **kwargs):
        self.target_max_score = target_max_score
        super().__init__(*args, **kwargs)

    def _is_goal_complete(self, model_output, _):
        """Return ``True`` when ``model_output`` satisfies the attack goal."""
        # Compare against None explicitly: the original truthiness test
        # (`if self.target_max_score:`) silently ignored a legitimate
        # threshold of 0.0, because 0.0 is falsy.
        if self.target_max_score is not None:
            return model_output[self.ground_truth_output] < self.target_max_score
        elif (model_output.numel() == 1) and isinstance(
            self.ground_truth_output, float
        ):
            # Single-scalar output with a float ground truth: treated as a
            # regression task; success means the output moved at least 0.5
            # away from the ground-truth value.
            return abs(self.ground_truth_output - model_output.item()) >= 0.5
        else:
            return model_output.argmax() != self.ground_truth_output

    def _get_score(self, model_output, _):
        """Return a score that is higher the closer the attack is to success."""
        # If the model outputs a single number and the ground truth output is
        # a float, we assume that this is a regression task.
        if (model_output.numel() == 1) and isinstance(self.ground_truth_output, float):
            # NOTE(review): max(output, ground_truth) does not obviously
            # measure distance from the ground truth; confirm this is the
            # intended regression scoring metric.
            return max(model_output.item(), self.ground_truth_output)
        else:
            # Lower ground-truth-class probability => higher attack score.
            return 1 - model_output[self.ground_truth_output]
Loading