Enable pointer-generator T5 models in BeamSearch #23134

Open · wants to merge 3 commits into main
65 changes: 45 additions & 20 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -20,8 +20,9 @@

 Inputs:
    input_ids: int32 (B, 1)
+   encoder_input_ids: int32 (B, encode_sequence_length) (optional)
    encoder_attention_mask: int32 (B, encode_sequence_length)
-   encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
+   encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional)

    past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
    past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
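For orientation, these are the decoder subgraph input orderings the validation below accepts, written out as a sketch inferred from this doc comment and the SetPastInputIndex change in subgraph_t5_decoder.h (the rows containing encoder_input_ids are what this PR adds; the last one is the pointer-generator layout):

// input_ids, encoder_attention_mask, past_*
// input_ids, encoder_attention_mask, encoder_hidden_states, past_*
// input_ids, encoder_input_ids, encoder_attention_mask, past_*
// input_ids, encoder_input_ids, encoder_attention_mask, encoder_hidden_states, past_*  (pointer-generator)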
@@ -49,11 +50,9 @@

 Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
                                    const std::vector<const NodeArg*>& subgraph_outputs) {
-  bool has_hidden_state = subgraph_inputs[2]->Name() == "encoder_hidden_states" ? true : false;
-  SetPastInputIndex(has_hidden_state);
-
-  ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3,
-                "kFirstPastInputIndex currently only supports 2 or 3");
+  bool has_encoder_input_ids = subgraph_inputs[1]->Name() == "encoder_input_ids";
+  bool has_hidden_state = subgraph_inputs[2 + has_encoder_input_ids]->Name() == "encoder_hidden_states";
+  SetPastInputIndex(has_hidden_state, has_encoder_input_ids);

if (!past_present_share_buffer_) {
ORT_RETURN_IF(has_decoder_masked_attention_, "decoder_masked_attention shall use with past_present_share_buffer");
@@ -75,13 +74,17 @@

ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
"decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
-  ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
-                "decoder subgraph input 1 shall be named as encoder_attention_mask, got: ",
-                subgraph_inputs[1]->Name());
-  if (first_past_input_index_ == 3) {
-    ORT_RETURN_IF(subgraph_inputs[2]->Name() != "encoder_hidden_states",
-                  "decoder subgraph input 2 shall be named as encoder_hidden_states, got: ",
-                  subgraph_inputs[2]->Name());
+  const int enc_attn_mask_index = 1 + has_encoder_input_ids_;
+  const int enc_hidden_state_index = enc_attn_mask_index + 1;
+  ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->Name() != "encoder_attention_mask",
+                "decoder subgraph input ", std::to_string(enc_attn_mask_index),
+                " shall be named as encoder_attention_mask, got: ",
+                subgraph_inputs[enc_attn_mask_index]->Name());
+  if (has_hidden_state_) {
+    ORT_RETURN_IF(subgraph_inputs[enc_hidden_state_index]->Name() != "encoder_hidden_states",
+                  "decoder subgraph input ", std::to_string(enc_hidden_state_index),
+                  " shall be named as encoder_hidden_states, got: ",
+                  subgraph_inputs[enc_hidden_state_index]->Name());
}

// check subgraph outputs
@@ -108,12 +111,19 @@

ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 0 (input_ids) shall have int32 type");
-  ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
-                "decoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
-
-  auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type();
-  ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
-                "decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type");
+  if (has_encoder_input_ids_) {
+    ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
+                  "decoder subgraph input 1 (encoder_input_ids) shall have int32 type");
+  }
+  ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->TypeAsProto()->tensor_type().elem_type() != int32_type,
+                "decoder subgraph input ", std::to_string(enc_attn_mask_index),
+                " (encoder_attention_mask) shall have int32 type");
+
+  auto float_type = subgraph_inputs[enc_hidden_state_index]->TypeAsProto()->tensor_type().elem_type();
+  if (has_hidden_state_) {
+    ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
+                  "decoder subgraph input ", std::to_string(enc_hidden_state_index), " (encoder_hidden_states) shall have float or float16 type");
+  }

for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) {
ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
@@ -219,6 +229,19 @@
decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
decoder_feeds.push_back(input_ids);

+  if (has_encoder_input_ids_) {
+    // encoder_input_ids is copied from the first input of the encoder.
+    OrtValue expanded_encoder_input_ids;
+    ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
+                                                 encoder_feeds[0],
+                                                 num_beam,
+                                                 allocator,
+                                                 expanded_encoder_input_ids,
+                                                 false,
+                                                 0 /*max_sequence_length*/));
+    decoder_feeds.push_back(expanded_encoder_input_ids);
+  }

// The encoder_attention_mask is copied from the second input of encoder.
OrtValue expanded_decoder_attention_masks;
ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
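For readers new to the beam-search helpers: expand_buffer_int32_func is used in both hunks above to tile a (batch, sequence) int32 tensor beam-wise, so every beam of a batch item sees the same encoder-side input. Below is a minimal standalone sketch of just that tiling behavior; ExpandForBeams is a hypothetical helper, and the real function's signature, stream, and allocator handling are ORT-specific:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Tile each of the `batch` rows `num_beams` times:
// (batch, seq_len) -> (batch * num_beams, seq_len).
std::vector<int32_t> ExpandForBeams(const std::vector<int32_t>& src,
                                    int batch, int seq_len, int num_beams) {
  std::vector<int32_t> dst;
  dst.reserve(static_cast<std::size_t>(batch) * num_beams * seq_len);
  for (int b = 0; b < batch; ++b) {
    const auto row = src.begin() + static_cast<std::ptrdiff_t>(b) * seq_len;
    for (int beam = 0; beam < num_beams; ++beam) {
      dst.insert(dst.end(), row, row + seq_len);  // same source row for every beam
    }
  }
  return dst;
}

int main() {
  // Two batch rows of length 2, expanded for 2 beams: {A, B} -> {A, A, B, B}.
  const std::vector<int32_t> expanded = ExpandForBeams({1, 2, 3, 4}, 2, 2, 2);
  assert((expanded == std::vector<int32_t>{1, 2, 1, 2, 3, 4, 3, 4}));
  return 0;
}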
@@ -238,7 +261,9 @@
 // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
 // of encoder.
 // When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
-  for (size_t j = static_cast<size_t>(4) - first_past_input_index_; j < encoder_fetches.size(); j++) {
+  // TODO - probably more robust to introduce an encoder_out/decoder_in mapping instead of relying on positions.
+  // What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds?
+  for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
    if (j == 1) {
      ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expansion: has_hidden_state_ == false");
      OrtValue expanded_hidden_states;
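The loop's start index is easy to misread, so here is my reading of the hunk above spelled out; the encoder fetch layout (logits first, hidden states second, present states after) is taken from the surrounding comments and assumed, not re-verified against the encoder subgraph:

// encoder_fetches: [0] logits, [1] encoder_hidden_states, [2]... present_* states.
// j starts at 2 - has_hidden_state_:
//   has_hidden_state_ == true  -> j = 1: expand encoder_hidden_states into the decoder feeds, then the presents
//   has_hidden_state_ == false -> j = 2: skip the hidden states and copy only the presents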
10 changes: 4 additions & 6 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
@@ -54,13 +54,10 @@ class T5DecoderSubgraph : public Subgraph {
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) override;

-  void SetPastInputIndex(bool has_hidden_state) {
+  void SetPastInputIndex(bool has_hidden_state, bool has_encoder_input_ids) {
     has_hidden_state_ = has_hidden_state;
-    if (!has_hidden_state_) {
-      first_past_input_index_ = 2;
-    } else {
-      first_past_input_index_ = 3;
-    }
+    has_encoder_input_ids_ = has_encoder_input_ids;
+    first_past_input_index_ = 2 + has_hidden_state_ + has_encoder_input_ids_;
}

int GetFirstPastInputIndex() const {
@@ -79,6 +76,7 @@
int first_past_input_index_;
int first_present_output_index_;
bool has_hidden_state_;
+  bool has_encoder_input_ids_;
bool use_sequence_as_input_ids_;
};

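The replaced if/else collapses into bool-to-int arithmetic. A self-contained sketch that mirrors the member function above (a hypothetical free function, compiled standalone rather than calling the class) and double-checks all four combinations:

#include <cassert>

// Past inputs follow: input_ids, [encoder_input_ids], encoder_attention_mask, [encoder_hidden_states].
int FirstPastInputIndex(bool has_hidden_state, bool has_encoder_input_ids) {
  return 2 + has_hidden_state + has_encoder_input_ids;  // bools promote to 0/1
}

int main() {
  assert(FirstPastInputIndex(false, false) == 2);  // neither optional input
  assert(FirstPastInputIndex(true, false) == 3);   // encoder_hidden_states only
  assert(FirstPastInputIndex(false, true) == 3);   // encoder_input_ids only
  assert(FirstPastInputIndex(true, true) == 4);    // pointer-generator T5: both
  return 0;
}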
22 changes: 22 additions & 0 deletions onnxruntime/test/contrib_ops/beam_search_test.cc
@@ -400,6 +400,8 @@ TEST(BeamSearchTest, DummyT5) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
+  // dummy_t5.onnx model generated using the following command:
+  // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5.onnx
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
@@ -414,6 +416,8 @@ TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
+  // dummy_t5_with_outer_scope_initializers.onnx model generated using the following command:
+  // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_outer_scope_initializers.onnx --move-initializers
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
@@ -428,6 +432,8 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
+  // dummy_t5_with_sequence_input_ids.onnx model generated using the following command:
+  // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_sequence_input_ids.onnx --sequence-as-input
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8});
@@ -438,5 +444,21 @@
tester.RunWithConfig();
}

+TEST(BeamSearchTest, DummyT5PointerGenerator) {
+#if defined(USE_CUDA) && defined(USE_DML)
+  SKIP_CUDA_TEST_WITH_DML;
+#endif
+  // dummy_t5_pointer_generator.onnx model generated using the following command:
+  // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_pointer_generator.onnx --decoder-needs-input-ids
+  ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_pointer_generator.onnx"));
+  tester.ConfigEp(DefaultCpuExecutionProvider());
+  tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
+  tester.AddOutput("sequences", {1, 3, 10}, {2, 3, 6, 7, 3, 6, 7, 18, 3, 6, 2, 3, 6, 7, 18, 3, 6, 7, 18, 3, 2, 3, 6, 7, 3, 6, 7, 3, 6, 7});
+#ifdef USE_CUDA
+  tester.ConfigEp(DefaultCudaExecutionProvider());
+#endif
+  tester.RunWithConfig();
+}

} // namespace test
} // namespace onnxruntime