-
Notifications
You must be signed in to change notification settings - Fork 0
/
configuration_t5.py
278 lines (247 loc) · 11.6 KB
/
configuration_t5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
# coding=utf-8
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" T5 model configuration """
from collections import OrderedDict
from typing import Any, Dict, Iterable, Mapping, Optional
from transformers import PreTrainedTokenizer, TensorType
from transformers import is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfigWithPast
from transformers.utils import logging
# Module-level logger from transformers' logging utilities (imported above as ``transformers.utils.logging``).
logger = logging.get_logger(__name__)
# Maps each canonical T5 checkpoint name to the URL of its ``config.json`` on the Hugging Face Hub.
# Every entry follows the same URL pattern, so the map is derived from the checkpoint names.
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    checkpoint: f"https://huggingface.co/{checkpoint}/resolve/main/config.json"
    for checkpoint in ("t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b")
}
class T5Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.T5Model` or a
    :class:`~transformers.TFT5Model`. It is used to instantiate a T5 model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
    to that of the T5 `t5-small <https://huggingface.co/t5-small>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Arguments:
        vocab_size (:obj:`int`, `optional`, defaults to 32128):
            Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
            :obj:`inputs_ids` passed when calling :class:`~transformers.T5Model` or :class:`~transformers.TFT5Model`.
        d_model (:obj:`int`, `optional`, defaults to 512):
            Size of the encoder layers and the pooler layer.
        d_kv (:obj:`int`, `optional`, defaults to 64):
            Size of the key, query, value projections per attention head. The attention inner dimension is
            :obj:`num_heads * d_kv`, which does not have to equal :obj:`d_model` (the published t5-3b and t5-11b
            checkpoints, for example, use :obj:`d_model=1024` with :obj:`d_kv=128`).
        d_ff (:obj:`int`, `optional`, defaults to 2048):
            Size of the intermediate feed forward layer in each :obj:`T5Block`.
        num_layers (:obj:`int`, `optional`, defaults to 6):
            Number of hidden layers in the Transformer encoder.
        num_decoder_layers (:obj:`int`, `optional`):
            Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not
            set.
        num_heads (:obj:`int`, `optional`, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        relative_attention_num_buckets (:obj:`int`, `optional`, defaults to 32):
            The number of buckets to use for each attention layer.
        dropout_rate (:obj:`float`, `optional`, defaults to 0.1):
            The ratio for all dropout layers.
        layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        initializer_factor (:obj:`float`, `optional`, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        feed_forward_proj (:obj:`str`, `optional`, defaults to :obj:`"relu"`):
            Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`. T5v1.1 uses
            the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
        is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Forwarded to :class:`~transformers.PretrainedConfig`.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        pad_token_id (:obj:`int`, `optional`, defaults to 0):
            Forwarded to :class:`~transformers.PretrainedConfig`.
        eos_token_id (:obj:`int`, `optional`, defaults to 1):
            Forwarded to :class:`~transformers.PretrainedConfig`.
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
        kwargs:
            Any remaining keyword arguments are forwarded to :class:`~transformers.PretrainedConfig`.

    Parameter-efficient fine-tuning options. This class only stores each one unchanged as an attribute of the same
    name. Their meaning is defined by the modeling code that reads them, which is not part of this file. The short
    descriptions below are inferred from the names and should be confirmed against that code.

        apply_lora (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Presumably enables LoRA low-rank weight updates.
        lora_alpha (`optional`, defaults to :obj:`None`):
            Presumably the LoRA scaling factor.
        lora_r (`optional`, defaults to :obj:`None`):
            Presumably the LoRA rank.
        apply_lora_BR (:obj:`bool`, `optional`, defaults to :obj:`False`):
            LoRA-variant switch; semantics not visible in this file.
        share_lora_R (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Presumably shares a LoRA factor across layers.
        lora_uniform (`optional`, defaults to 5):
            LoRA-related setting; semantics not visible in this file.
        r_mean (`optional`, defaults to 0):
            Presumably the mean used to initialize the added parameters.
        r_std (`optional`, defaults to 0.02):
            Presumably the standard deviation used to initialize the added parameters.
        apply_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Presumably enables adapter layers.
        adapter_type (`optional`, defaults to :obj:`None`):
            Presumably selects the adapter variant.
        adapter_size (`optional`, defaults to :obj:`None`):
            Presumably the adapter bottleneck size.
        apply_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Bias-tuning switch; semantics not visible in this file.
        apply_bias_stage2 (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Second-stage bias-tuning switch; semantics not visible in this file.
        decoder_mlp (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Semantics not visible in this file.
        share_intrinsic (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Presumably shares an intrinsic-dimension reparameterization.
        intrinsic_dim (`optional`, defaults to :obj:`None`):
            Presumably the intrinsic dimension size.
        apply_prefix (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Presumably enables prefix tuning.
        prefix_num (:obj:`int`, `optional`, defaults to 24):
            Presumably the number of prefix vectors.
        prefix_r (:obj:`int`, `optional`, defaults to 120):
            Presumably the width of the prefix reparameterization.
    """

    model_type = "t5"
    # Keys that PretrainedConfig-aware utilities ignore by default when inspecting dictionary model outputs at
    # inference time (here: the cached attention states).
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32128,
        d_model=512,
        d_kv=64,
        d_ff=2048,
        num_layers=6,
        num_decoder_layers=None,
        num_heads=8,
        relative_attention_num_buckets=32,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        feed_forward_proj="relu",
        is_encoder_decoder=True,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        gradient_checkpointing=False,
        apply_lora=False,
        lora_alpha=None,
        lora_r=None,
        apply_adapter=False,
        adapter_type=None,
        adapter_size=None,
        apply_lora_BR=False,
        apply_bias=False,
        apply_bias_stage2=False,
        decoder_mlp=False,
        share_lora_R=False,
        share_intrinsic=False,
        intrinsic_dim=None,
        apply_prefix=False,
        prefix_num=24,
        prefix_r=120,
        r_mean=0,
        r_std=0.02,
        lora_uniform=5,
        **kwargs
    ):
        # Token ids and the encoder-decoder flag are owned by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )

        # Core T5 architecture.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        # Default to a symmetric model: same depth for the decoder as for the encoder.
        self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj
        self.use_cache = use_cache
        self.gradient_checkpointing = gradient_checkpointing

        # Parameter-efficient fine-tuning options, stored verbatim for the modeling code to interpret.
        self.apply_lora = apply_lora
        self.lora_alpha = lora_alpha
        self.lora_r = lora_r
        self.apply_adapter = apply_adapter
        self.adapter_type = adapter_type
        self.adapter_size = adapter_size
        self.apply_lora_BR = apply_lora_BR
        self.apply_bias = apply_bias
        self.apply_bias_stage2 = apply_bias_stage2
        self.decoder_mlp = decoder_mlp
        self.share_lora_R = share_lora_R
        self.share_intrinsic = share_intrinsic
        self.intrinsic_dim = intrinsic_dim
        self.apply_prefix = apply_prefix
        self.prefix_num = prefix_num
        self.prefix_r = prefix_r
        self.r_mean = r_mean
        self.r_std = r_std
        self.lora_uniform = lora_uniform

    # Read-only aliases that expose T5's attribute names under the generic names other code expects
    # (e.g. the ONNX config reads ``hidden_size``).
    @property
    def hidden_size(self):
        """Alias for :obj:`d_model`."""
        return self.d_model

    @property
    def num_attention_heads(self):
        """Alias for :obj:`num_heads`."""
        return self.num_heads

    @property
    def num_hidden_layers(self):
        """Alias for :obj:`num_layers` (the encoder depth)."""
        return self.num_layers
class T5OnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for T5 models, optionally including cached past key/values.

    It declares the dynamic axes of the exported graph's inputs and outputs, and builds dummy inputs for tracing.
    When ``use_past`` is set, each decoder layer contributes four cached tensors: the decoder ("decoder")
    key/value and the encoder ("encoder") key/value. Each tensor is laid out as
    ``(batch, num_heads, sequence, d_kv)``, matching the dummy tensors built in :meth:`generate_dummy_inputs`.
    """

    @property
    def _num_decoder_layers(self) -> int:
        """Number of decoder layers, i.e. how many past key/value tuples the model caches."""
        # The cache has one entry per *decoder* layer, which can differ from the encoder depth ``num_layers``.
        # Fall back to ``num_layers`` for config objects that do not define ``num_decoder_layers``.
        num_decoder_layers = getattr(self._config, "num_decoder_layers", None)
        return self._config.num_layers if num_decoder_layers is None else num_decoder_layers

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Dynamic axes of every model input, keyed by input name."""
        common_inputs = OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ("decoder_input_ids", {0: "batch"}),
                ("decoder_attention_mask", {0: "batch"}),
            ]
        )
        if self.use_past:
            for i in range(self._num_decoder_layers):
                # The decoder and encoder caches have different sequence lengths (see ``decoder_shape`` vs
                # ``encoder_shape`` in generate_dummy_inputs), so they must not share one symbolic axis name.
                common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_decoder_sequence"}
                common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_decoder_sequence"}
                common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_encoder_sequence"}
                common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_encoder_sequence"}
        return common_inputs

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Dynamic axes of every model output, keyed by output name."""
        common_outputs = super().outputs
        if "last_hidden_state" in common_outputs:
            common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"}
        if self.use_past:
            for i in range(self._num_decoder_layers):
                common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"}
                common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"}
                common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"}
                common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"}
        if self.task == "default":
            # Hidden states are (batch, sequence, hidden), like ``last_hidden_state`` above: the sequence axis is 1.
            # Axis 2 is the fixed model width and must not be declared dynamic.
            common_outputs["encoder_last_hidden_state"] = {0: "batch", 1: "encoder_sequence"}
        return common_outputs

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        """Build dummy encoder/decoder inputs (and zero-filled past key/values when ``use_past``) for tracing.

        Args:
            tokenizer: Tokenizer forwarded to the base implementation to produce the token inputs.
            batch_size: Batch size forwarded to the base implementation.
            seq_length: Encoder sequence length forwarded to the base implementation; the decoder uses length 1.
            is_pair: Forwarded to the base implementation.
            framework: Tensor framework forwarded to the base implementation.

        Returns:
            A dict with the encoder inputs, the decoder inputs prefixed with ``decoder_``, and, when ``use_past`` is
            set, a ``past_key_values`` list with one 4-tuple of zero tensors per decoder layer.

        Raises:
            ValueError: If ``use_past`` is set but PyTorch is not installed.
        """
        encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
        # A single decoder step (sequence length 1), renamed with a ``decoder_`` prefix.
        decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
        ordered_inputs = dict(**encoder_inputs, **decoder_inputs)

        if not self.use_past:
            return ordered_inputs

        if not is_torch_available():
            raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
        import torch

        batch = encoder_inputs["input_ids"].shape[0]
        encoder_seq_length = encoder_inputs["input_ids"].shape[1]
        num_heads = self._config.num_heads
        # T5 heads have width ``d_kv``, which need not equal ``hidden_size // num_heads``
        # (e.g. t5-3b / t5-11b use d_model=1024 with d_kv=128).
        head_dim = self._config.d_kv
        encoder_shape = (batch, num_heads, encoder_seq_length, head_dim)
        decoder_shape = (batch, num_heads, 1, head_dim)
        # Tuple order per layer: decoder key, decoder value, encoder key, encoder value.
        ordered_inputs["past_key_values"] = [
            (
                torch.zeros(decoder_shape),
                torch.zeros(decoder_shape),
                torch.zeros(encoder_shape),
                torch.zeros(encoder_shape),
            )
            for _ in range(self._num_decoder_layers)
        ]
        return ordered_inputs

    @staticmethod
    def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
        """Flatten per-layer past/present 4-tuples into ``{name}.{layer}.{decoder|encoder}.{key|value}`` entries.

        Any other ``name`` is delegated to the base implementation.
        """
        if name in ["present", "past_key_values"]:
            flatten_output = {}
            for idx, t in enumerate(field):
                flatten_output[f"{name}.{idx}.decoder.key"] = t[0]
                flatten_output[f"{name}.{idx}.decoder.value"] = t[1]
                flatten_output[f"{name}.{idx}.encoder.key"] = t[2]
                flatten_output[f"{name}.{idx}.encoder.value"] = t[3]
            return flatten_output
        return super().flatten_output_collection_property(name, field)