Skip to content

Commit

Permalink
whisper : use FA (flash attention) for cross-attention
Browse files (browse the repository at this point in the history)
  • Loading branch information
ggerganov committed May 14, 2024
1 parent 1789d3d commit 0935288
Showing 1 changed file with 62 additions and 21 deletions.
83 changes: 62 additions & 21 deletions whisper.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -2128,6 +2128,8 @@ static struct ggml_cgraph * whisper_build_graph_cross(

const int n_state_head = n_state/n_head;

const int n_ctx_pad = GGML_PAD(n_ctx, 256);

struct ggml_init_params params = {
/*.mem_size =*/ wstate.alloc_cross.meta.size(),
/*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
Expand All @@ -2145,13 +2147,34 @@ static struct ggml_cgraph * whisper_build_graph_cross(
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
auto & layer = model.layers_decoder[il];

struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
layer.cross_attn_k_w,
cur);

Kcross = ggml_scale(ctx0, Kcross, Kscale);

struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
#ifdef WHISPER_USE_FLASH_ATTN
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
layer.cross_attn_v_w,
cur);

Vcross = ggml_add(ctx0,
Vcross,
layer.cross_attn_v_b);

struct ggml_tensor * k = ggml_view_3d(ctx0, wstate.kv_cross.k,
n_state_head, n_ctx, n_head,
(ggml_element_size(wstate.kv_cross.k)*n_state_head),
(ggml_element_size(wstate.kv_cross.k)*n_state_head*n_ctx_pad),
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));

struct ggml_tensor * v = ggml_view_3d(ctx0, wstate.kv_cross.v,
n_state_head, n_ctx, n_head,
(ggml_element_size(wstate.kv_cross.v)*n_state_head),
(ggml_element_size(wstate.kv_cross.v)*n_state_head*n_ctx_pad),
(ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
#else
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
layer.cross_attn_v_w,
cur);

Expand All @@ -2168,6 +2191,7 @@ static struct ggml_cgraph * whisper_build_graph_cross(
struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
#endif

ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
Expand Down Expand Up @@ -2311,6 +2335,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
const int n_tokens = batch.n_tokens;
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;

const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);

const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;

Expand Down Expand Up @@ -2477,27 +2503,46 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
Qcur,
layer.cross_attn_q_b);

Qcur = ggml_scale(ctx0, Qcur, KQscale);
#ifdef WHISPER_USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
0, 2, 1, 3);

// Kcross is already scaled
struct ggml_tensor * Kcross =
ggml_view_3d(ctx0, wstate.kv_cross.k,
n_state_head, n_audio_ctx, n_head,
n_state_head, n_audio_ctx_pad, n_head,
ggml_element_size(wstate.kv_cross.k)*n_state,
ggml_element_size(wstate.kv_cross.k)*n_state_head,
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);

//struct ggml_tensor * Vcross =
// ggml_reshape_3d(ctx0,
// ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
// n_state_head, n_head, n_audio_ctx);
struct ggml_tensor * Vcross =
ggml_view_3d(ctx0, wstate.kv_cross.v,
n_state_head, n_audio_ctx_pad, n_head,
ggml_element_size(wstate.kv_cross.v)*n_state,
ggml_element_size(wstate.kv_cross.v)*n_state_head,
ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);

//struct ggml_tensor * V_trans =
// ggml_cpy(ctx0,
// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
// ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state_head, n_head));
cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f);

struct ggml_tensor * V =
cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
#else
Qcur = ggml_scale(ctx0, Qcur, KQscale);

struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
0, 2, 1, 3);

struct ggml_tensor * Kcross =
ggml_view_3d(ctx0, wstate.kv_cross.k,
n_state_head, n_audio_ctx, n_head,
ggml_element_size(wstate.kv_cross.k)*n_state,
ggml_element_size(wstate.kv_cross.k)*n_state_head,
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);

struct ggml_tensor * Vcross =
ggml_view_3d(ctx0, wstate.kv_cross.v,
n_audio_ctx, n_state_head, n_head,
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
Expand All @@ -2506,11 +2551,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

// ------

struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
0, 2, 1, 3);

// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);

Expand Down Expand Up @@ -2543,14 +2583,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
}
}

struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);

struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

// cur = KQV_merged.contiguous().view(n_state, n_tokens)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
#endif
}

// projection
Expand Down Expand Up @@ -3210,7 +3251,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
ctx->model.hparams.n_text_state,
ctx->model.hparams.n_text_layer,
ctx->model.hparams.n_audio_ctx)) {
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
whisper_free_state(state);
return nullptr;
Expand Down

0 comments on commit 0935288

Please sign in to comment.