# Use prompt length to get pre-defined prompt
def get_prompt_by_length(prompt_length, json_path=None):
    """Return the pre-defined prompt stored for ``prompt_length``.

    The prompts live in a JSON object that maps a prompt length (as a
    string key, e.g. ``"16"``) to the prompt stored for that length.

    Args:
        prompt_length: Length to look up. It is converted with ``str()``
            before the lookup, so ``16`` and ``"16"`` are equivalent.
        json_path: Optional path to the prompt JSON file. When omitted,
            ``prompts.json`` located next to this script is used, so the
            lookup does not depend on the current working directory.

    Returns:
        The value stored under ``str(prompt_length)`` in the JSON file.

    Raises:
        FileNotFoundError: If the prompt file does not exist.
        json.JSONDecodeError: If the prompt file is not valid JSON.
        KeyError: If no prompt is defined for ``prompt_length``.
    """
    # Imported locally so this function does not depend on the module's
    # top-level import block.
    import json
    import os

    if json_path is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        json_path = os.path.join(script_dir, "prompts.json")

    with open(json_path, encoding="utf-8") as prompts_file:
        # json.load() takes the file object itself; passing the already-read
        # text (a str) fails because str has no .read() method.
        data = json.load(prompts_file)

    key = str(prompt_length)
    try:
        return data[key]
    except KeyError:
        available = ", ".join(str(k) for k in data)
        raise KeyError(
            f"No pre-defined prompt for length {key} in {json_path} "
            f"(available lengths: {available})"
        ) from None
000000000..5c8334fa4 --- /dev/null +++ b/benchmark/python/prompts.json @@ -0,0 +1,7 @@ +{ + "16": "How are astronauts launched into space quickly on those rockets? ", + "64": "", + "256": "", + "1024": "", + "2048": "" +} \ No newline at end of file