Skip to content

Commit

Permalink
Merge pull request #635 from QData/hard-label-attack
Browse files Browse the repository at this point in the history
hard label classification
  • Loading branch information
qiyanjun authored Sep 11, 2023
2 parents cab4e0f + 87c4671 commit f848247
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
5 changes: 1 addition & 4 deletions textattack/attack_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
#
# Classification goal functions
#
"hardlabel-classification": "textattack.goal_functions.classification.HardLabelClassification",
"targeted-classification": "textattack.goal_functions.classification.TargetedClassification",
"untargeted-classification": "textattack.goal_functions.classification.UntargetedClassification",
"input-reduction": "textattack.goal_functions.classification.InputReduction",
Expand All @@ -127,15 +128,13 @@
@dataclass
class AttackArgs:
"""Attack arguments to be passed to :class:`~textattack.Attacker`.
Args:
num_examples (:obj:`int`, 'optional`, defaults to :obj:`10`):
The number of examples to attack. :obj:`-1` for entire dataset.
num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
The number of successful adversarial examples we want. This is different from :obj:`num_examples`
as :obj:`num_examples` only cares about attacking `N` samples while :obj:`num_successful_examples` aims to keep attacking
until we have `N` successful cases.
.. note::
If set, this argument overrides `num_examples` argument.
num_examples_offset (:obj: `int`, `optional`, defaults to :obj:`0`):
Expand All @@ -149,7 +148,6 @@ class AttackArgs:
query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
The maximum number of model queries allowed per example attacked.
If not set, we use the query budget set in the :class:`~textattack.goal_functions.GoalFunction` object (which by default is :obj:`float("inf")`).
.. note::
Setting this overwrites the query budget set in :class:`~textattack.goal_functions.GoalFunction` object.
checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
Expand Down Expand Up @@ -468,7 +466,6 @@ def create_loggers_from_args(cls, args):
class _CommandLineAttackArgs:
"""Attack args for command line execution. This requires more arguments to
create ``Attack`` object as specified.
Args:
transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
Name of transformation to use.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Determine if an attack has been successful in Hard Label Classficiation.
----------------------------------------------------
"""


from .classification_goal_function import ClassificationGoalFunction


class HardLabelClassification(ClassificationGoalFunction):
"""An hard label attack on classification models which attempts to maximize
the semantic similarity of the label such that the target is outside of the
decision boundary.
Args:
target_max_score (float): If set, goal is to reduce model output to
below this score. Otherwise, goal is to change the overall predicted
class.
"""

def __init__(self, *args, target_max_score=None, **kwargs):
self.target_max_score = target_max_score
super().__init__(*args, **kwargs)

def _is_goal_complete(self, model_output, _):
if self.target_max_score:
return model_output[self.ground_truth_output] < self.target_max_score
elif (model_output.numel() == 1) and isinstance(
self.ground_truth_output, float
):
return abs(self.ground_truth_output - model_output.item()) >= 0.5
else:
return model_output.argmax() != self.ground_truth_output

def _get_score(self, model_output, _):
# If the model outputs a single number and the ground truth output is
# a float, we assume that this is a regression task.
if (model_output.numel() == 1) and isinstance(self.ground_truth_output, float):
return max(model_output.item(), self.ground_truth_output)
else:
return 1 - model_output[self.ground_truth_output]

0 comments on commit f848247

Please sign in to comment.