-
Notifications
You must be signed in to change notification settings - Fork 5
/
ort-t5.html
executable file
·307 lines (267 loc) · 9.19 KB
/
ort-t5.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
<html>
<script type="module">
import { AutoTokenizer } from './transformers.js';
//import { AutoTokenizer } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers/dist/transformers.js'
window.AutoTokenizer = AutoTokenizer;
</script>
<script src="dist/ort.webgpu.min.js"> </script>
<body>
<h1 id="h1-title">t5 end to end test</h1>
<a href="ort-t5.html?provider=webgpu">webgpu</a><br />
<a href="ort-t5.html?provider=webgpu&io_binding=1">webgpu+io_binding</a><br />
<a href="ort-t5.html?provider=wasm">wasm</a><br />
<a href="ort-t5.html?provider=wasm&io_binding=1">wasm+io_binding</a><br />
<p class="text-lg-start">
<div id="status" style="font: 1em consolas;"></div>
</p>
</body>
<script>
// Buffer for console output captured by redirect_log(); run() joins it into
// the downloadable profiler log when it is non-empty.
var cons_out = [];
// Run the wasm backend with a single thread.
// NOTE(review): presumably to avoid needing cross-origin isolation for
// multi-threaded wasm -- confirm.
ort.env.wasm.numThreads = 1;
/**
 * Builds the run configuration from the page's query string.
 *
 * Recognized keys and their defaults:
 *   provider   - execution provider name, default "webgpu"
 *   profiler   - integer flag, default 0
 *   verbose    - integer flag, default 0
 *   io_binding - "1" enables it, anything else disables it
 *
 * Values that do not parse as an integer are treated as 0 instead of NaN.
 * Parameters with an empty name are ignored.
 *
 * @returns {{provider: string, profiler: number, verbose: number, io_binding: boolean}}
 * @throws {Error} "unknown argument: <name>" for any other non-empty key.
 */
function getConfig() {
    const config = {
        provider: "webgpu",
        profiler: "0",
        verbose: "0",
        io_binding: "0",
    };
    // URLSearchParams copes with the leading "?", percent-decoding, missing
    // values and "=" inside values.
    const params = new URLSearchParams(window.location.search);
    for (const [key, value] of params) {
        // Own-property check: `in` would also accept inherited names such as
        // "toString" or "constructor" and silently overwrite them.
        if (Object.prototype.hasOwnProperty.call(config, key)) {
            config[key] = value;
        } else if (key.length > 0) {
            throw new Error("unknown argument: " + key);
        }
    }
    config.verbose = Number.parseInt(config.verbose, 10) || 0;
    config.profiler = Number.parseInt(config.profiler, 10) || 0;
    config.io_binding = (config.io_binding === "1");
    return config;
}
/**
 * Appends a timestamped line to the #status element and echoes it to
 * console.log.
 *
 * @param {*} i - value to log; it is interpolated into a string.
 */
function log(i) {
    const stamp = performance.now().toFixed(3);
    const line = `${stamp}| ${i}`;
    const statusEl = document.getElementById('status');
    statusEl.innerText = `${statusEl.innerText}\n${line}`;
    console.log(line);
}
/**
 * Creates an ort.Tensor backed by a BigInt64Array in which every element is 1n.
 *
 * @param {number[]} shape - tensor dimensions; the element count is their product.
 * @returns {ort.Tensor} tensor of the given shape filled with 1n.
 */
function ones(shape) {
    const count = shape.reduce((product, dim) => product * dim, 1);
    const data = BigInt64Array.from({ length: count }, () => 1n);
    return new ort.Tensor(data, shape);
}
/**
 * Finds the largest element of an indexable sequence and its position.
 *
 * On ties the first occurrence wins. For an empty input the result is
 * [undefined, 0].
 *
 * @param {ArrayLike<number>} arr - values to scan.
 * @returns {[number, number]} pair of [largest value, index of that value].
 */
function max(arr) {
    let bestIdx = 0;
    for (let k = 1, n = arr.length; k < n; k += 1) {
        if (arr[k] > arr[bestIdx]) {
            bestIdx = k;
        }
    }
    return [arr[bestIdx], bestIdx];
}
/**
 * Loads a URL as an ArrayBuffer, going through the "onnx" Cache Storage
 * bucket when possible.
 *
 * If anything in the cache path fails (Cache Storage unavailable, cache.add
 * rejecting, ...) the URL is fetched directly instead. That fallback now
 * rejects on a non-2xx response, so an HTTP error page is never returned as
 * if it were the requested bytes.
 *
 * @param {string} url - resource to load.
 * @returns {Promise<ArrayBuffer>} the response body.
 * @throws {Error} when the fallback fetch answers with a non-OK status; the
 *   cache-path failure is attached as `cause`.
 */
async function fetchAndCache(url) {
    try {
        const cache = await caches.open("onnx");
        let cachedResponse = await cache.match(url);
        if (cachedResponse == null) {
            await cache.add(url);
            cachedResponse = await cache.match(url);
            log(`${url} (from network)`);
        } else {
            log(`${url} (cached)`);
        }
        // `return await` keeps a failing body read inside this try block so
        // the network fallback below still applies.
        return await cachedResponse.arrayBuffer();
    } catch (error) {
        log(`${url} (from network)`);
        const response = await fetch(url);
        if (!response.ok) {
            throw new Error(
                `failed to fetch ${url}: ${response.status} ${response.statusText}`,
                { cause: error },
            );
        }
        return await response.arrayBuffer();
    }
}
/**
 * Moves the `present*` outputs of a decoder step into the matching
 * `past_key_values*` slots of the feed used for the next step.
 *
 * Outputs whose names do not start with "present" are left alone.
 *
 * @param {Object} feed - decoder inputs; updated in place.
 * @param {Object} output - results of the decoder run.
 */
function copy_kv_cache(feed, output) {
    Object.keys(output)
        .filter((key) => key.startsWith('present'))
        .forEach((key) => {
            const tensor = output[key];
            if (tensor.dims[0] == 0) {
                // Optimization introduced by optimum to reuse past key values:
                // an empty output means the feed entry stays as it is.
                // https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
                tensor.dispose();
                return;
            }
            const feedKey = key.replace('present', 'past_key_values');
            // Free the previous tensor before the new one takes its slot.
            feed[feedKey].dispose();
            feed[feedKey] = tensor;
        });
}
/**
 * Replaces console.log with a collector that appends the first argument of
 * every call to the global `cons_out` array instead of printing it.
 */
function redirect_log() {
    console.log = (message) => {
        cons_out.push(message);
    };
}
// Parse the query string once; the rest of the script reads `config`.
const config = getConfig();
{
    // Show the active provider / io_binding combination in the page heading.
    const heading = document.getElementById('h1-title');
    heading.innerText = `t5-small ${config.provider} io_binding=${config.io_binding}`;
}
// With profiling on, collect console output in cons_out from here on.
if (config.profiler) {
    redirect_log();
}
/**
 * Greedy decoding step: picks the highest-scoring token from the decoder
 * logits, appends it to `output_ids`, and logs the text decoded so far.
 *
 * Only the scores of the last sequence position are considered, i.e. the
 * final `vocab` values of the flat logits buffer, where `vocab` is the last
 * entry of `output.logits.dims`. The previous `slice(null, -1, null)` was a
 * mistranslated Python `[:, -1, :]`: on a flat array it is `slice(0, -1)`,
 * which dropped the last vocabulary entry and mixed all positions together.
 *
 * @param {Object} tokenizer - must provide `decode_single(ids, options)`.
 * @param {Object} output - decoder results; `output.logits` needs `data` and `dims`.
 * @param {number[]} output_ids - token ids generated so far; the new id is pushed onto it.
 * @returns {bigint} the selected token id.
 */
function post_process(tokenizer, output, output_ids) {
    const { data, dims } = output.logits;
    // If dims are unavailable, fall back to treating the whole buffer as one row.
    const vocabSize = (dims && dims.length > 0) ? dims[dims.length - 1] : data.length;
    const lastStepLogits = data.slice(data.length - vocabSize);
    const argmax = max(lastStepLogits)[1];
    output_ids.push(argmax);
    const txt = tokenizer.decode_single(output_ids, { skip_special_tokens: true });
    log(txt);
    return BigInt(argmax);
}
/**
 * Debug helper: logs one line per feed tensor that is not located in a GPU
 * buffer (name, data location, dims, type, size), followed by a "----"
 * separator line.
 *
 * @param {Object} feed - map of input name to tensor.
 */
function dump_feed(feed) {
    Object.entries(feed).forEach(([key, tensor]) => {
        if (tensor.dataLocation != 'gpu-buffer') {
            log(`${key} ${tensor.dataLocation} ${tensor.dims}.${tensor.type} ${tensor.size}`);
        }
    });
    log("----");
}
/**
 * Experimental helper: uploads every non-empty CPU tensor in `feed` whose
 * name contains "encoder." into a WebGPU storage buffer and replaces the feed
 * entry with a tensor created from that buffer.
 *
 * The previous body read an undefined `input` variable (ReferenceError on the
 * first matching entry), never used the buffer it created, and re-assigned
 * the unchanged tensor back into the feed.
 *
 * NOTE(review): the intent (move encoder KV tensors to the GPU) is inferred
 * from the name and the buffer-creation code; nothing visible calls this
 * function -- confirm before relying on it.
 *
 * @param {Object} feed - map of input name to tensor; updated in place.
 */
function hack_encoder_kv(feed) {
    const dev = ort.env.webgpu.device;
    for (const [name, t] of Object.entries(feed)) {
        if (!name.includes("encoder.")) {
            continue;
        }
        // Nothing to upload for tensors already on the GPU or with no elements.
        if (t.dataLocation === 'gpu-buffer' || t.size === 0) {
            continue;
        }
        const input = t.data;
        const bytes = new Uint8Array(input.buffer, input.byteOffset, input.byteLength);
        const buf = dev.createBuffer({
            mappedAtCreation: true,
            // Rounded up to 16 bytes; mappedAtCreation needs a multiple of 4.
            size: Math.ceil(bytes.byteLength / 16) * 16,
            usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
        });
        // Byte-wise copy so the upload works for any tensor element type.
        new Uint8Array(buf.getMappedRange()).set(bytes);
        buf.unmap();
        feed[name] = ort.Tensor.fromGpuBuffer(buf, {
            dataType: t.type,
            dims: t.dims,
            // Disposing the tensor frees the buffer allocated above.
            dispose: () => buf.destroy(),
        });
    }
}
/**
 * End-to-end T5 test: tokenizes a fixed prompt, runs the encoder once, then
 * runs the merged decoder greedily one token at a time until the stop token
 * (or the step limit) is reached, logging the decoded text and timings.
 *
 * With profiling enabled, decoder profiling is ended and everything collected
 * in `cons_out` is offered as a "profiler.log" download link.
 *
 * Errors are logged, not rethrown. Both sessions are released in `finally`,
 * so they are also freed when a step fails.
 *
 * @returns {Promise<void>}
 */
async function run() {
    // NOTE(review): 6 layers and the [1, 8, 0, 64] cache shape are taken from
    // the original hard-coded names/shape; presumably they match t5-small.
    const NUM_LAYERS = 6;
    // The loop stops on token id 1 -- presumably the tokenizer's EOS id.
    const EOS_TOKEN_ID = 1n;
    // Safety bound so a model that never emits the stop token cannot spin forever.
    const MAX_NEW_TOKENS = 256;

    // Declared outside the try block so `finally` can release what was created.
    let encoder;
    let decoder;
    try {
        const tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
        const text = 'translate English to French: Hello, how are you?';
        const { input_ids } = await tokenizer(text);
        const encoder_attention_mask = ones(input_ids.dims);

        // Suffixes shared by the decoder's `past_key_values*` inputs and
        // `present*` outputs: decoder entries for every layer, then encoder ones.
        const kv_names = [];
        for (const part of ['decoder', 'encoder']) {
            for (let layer = 0; layer < NUM_LAYERS; layer++) {
                kv_names.push(`.${layer}.${part}.key`, `.${layer}.${part}.value`);
            }
        }

        const eopt = {
            executionProviders: [config.provider]
        };
        const dopt = {
            executionProviders: [config.provider]
        };
        if (config.verbose) {
            ort.env.debug = true;
            ort.env.logLevel = 'verbose';
            dopt.logSeverityLevel = 0;
            dopt.logVerbosityLevel = 0;
        }

        // GPU-resident outputs are only requested for the webgpu provider; the
        // encoder now uses the same guard the decoder already had.
        const gpuOutputs = config.io_binding && config.provider === "webgpu";
        if (gpuOutputs) {
            eopt.preferredOutputLocation = { 'last_hidden_state': "gpu-buffer" };
        }
        encoder = await ort.InferenceSession.create(await fetchAndCache("models/tjs/t5-small/onnx/encoder_model.onnx"), eopt);

        if (gpuOutputs) {
            dopt.preferredOutputLocation = {};
            kv_names.forEach(name => {
                dopt.preferredOutputLocation[`present${name}`] = "gpu-buffer";
            });
            dopt.preferredOutputLocation['last_hidden_state'] = "gpu-buffer";
        }
        if (config.profiler) {
            dopt.enableProfiling = true;
            ort.env.logLevel = "verbose";
            if (config.provider === "webgpu") {
                ort.env.webgpu.profilingMode = 'default';
            }
        }
        decoder = await ort.InferenceSession.create(await fetchAndCache("models/tjs/t5-small/onnx/decoder_model_merged.onnx"), dopt);

        log("running encoder");
        const efeed = { input_ids: input_ids, attention_mask: encoder_attention_mask };
        const eres = await encoder.run(efeed);
        const hidden_state = eres.last_hidden_state;

        let seqlen = 0;
        let last_token = 0n;
        const output_ids = [];
        const timings = [];
        const dfeed = {};

        // Start every past_key_values input as an empty tensor.
        kv_names.forEach(name => {
            dfeed["past_key_values" + name] = new ort.Tensor('float32', [], [1, 8, 0, 64]);
        });
        const input_ids_tensor = new ort.Tensor(new BigInt64Array(1), [1, 1]);
        const use_cache_branch_tensor = new ort.Tensor("bool", [false], [1]);
        dfeed.input_ids = input_ids_tensor; // cpu tensor scalar
        dfeed.use_cache_branch = use_cache_branch_tensor; // cpu tensor scalar
        dfeed.encoder_hidden_states = hidden_state; // (IO-Binding=ON) gpu tensor (output from encoder)
        dfeed.encoder_attention_mask = encoder_attention_mask; // cpu tensor (const)

        log("running decoder");
        do {
            // The two scalar inputs are updated in place between steps.
            input_ids_tensor.data[0] = last_token;
            use_cache_branch_tensor.data[0] = seqlen > 0;
            // dump_feed(dfeed);
            const start = performance.now();
            const output = await decoder.run(dfeed);
            const stop = performance.now();
            timings.push(stop - start);
            last_token = post_process(tokenizer, output, output_ids);
            copy_kv_cache(dfeed, output);
            seqlen++;
        } while (last_token !== EOS_TOKEN_ID && seqlen < MAX_NEW_TOKENS);

        // ignore 1st decoder run
        const first = timings.shift();
        // With a single decoder step there is nothing left to average.
        const avgText = timings.length > 0
            ? `${(timings.reduce((a, b) => a + b, 0) / timings.length).toFixed(2)}ms`
            : 'n/a';
        log(`Took: first=${first.toFixed(2)}ms, avg=${avgText}, tokens=${seqlen}`);
        console.log(timings);

        // release tensors
        for (const t of Object.values(dfeed)) {
            t.dispose();
        }
        if (config.profiler) {
            await decoder.endProfiling();
        }
        if (cons_out.length > 0) {
            // A Blob URL instead of btoa(): btoa throws on characters outside
            // Latin-1, which decoded model text may contain.
            const blob = new Blob([cons_out.join('\n')], { type: 'application/json' });
            const link = document.createElement("a");
            link.href = URL.createObjectURL(blob);
            link.download = "profiler.log";
            link.innerText = "Download";
            document.body.appendChild(link);
        }
    } catch (e) {
        console.log(e);
        log(e.message);
    } finally {
        // Release whatever was created, even after a failure; a failing
        // release is logged so it cannot surface as an unhandled rejection.
        for (const session of [encoder, decoder]) {
            try {
                await session?.release();
            } catch (e) {
                console.log(e);
                log(e.message);
            }
        }
    }
}
// Start the test once the document has been parsed.
document.addEventListener("DOMContentLoaded", function onReady() {
    run();
});
</script>
</html>