Multi-Head Self-Attention

Explore how multi-head attention expands upon self-attention.

The idea of self-attention can be expanded to multi-head attention. In essence, we run through the attention mechanism several times.

Each time, we independently map the Query, Key, and Value matrices into different lower-dimensional spaces and compute the attention there. Each individual output is called a “head”. The mapping is achieved by multiplying each matrix with a separate weight matrix, denoted as $W_i^Q, W_i^K \in \mathbb{R}^{d_{model} \times d_k}$ and $W_i^V \in \mathbb{R}^{d_{model} \times d_k}$, where $i$ is the head index.

To compensate for the extra complexity, the output vector size is divided by the number of heads. Specifically, the vanilla Transformer uses $d_{model} = 512$ and $h = 8$ heads, which gives vector representations of size $d_k = d_{model}/h = 64$.
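To make the shape bookkeeping concrete, here is a minimal NumPy check of that arithmetic; the sequence length of 10 and the random matrices are purely illustrative assumptions:

```python
import numpy as np

d_model, h = 512, 8        # sizes used in the vanilla Transformer
d_k = d_model // h         # 64: dimensionality of each head's subspace

seq_len = 10               # hypothetical sequence length, for illustration only
X = np.random.randn(seq_len, d_model)    # token representations (Q = K = V = X here)
W_q_i = np.random.randn(d_model, d_k)    # one head's query projection

Q_i = X @ W_q_i            # projected queries for head i
print(Q_i.shape)           # (10, 64): the lower-dimensional space this head works in
```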

With multi-head attention, the model has multiple independent paths through which to understand the input.

The heads are then concatenated and transformed using a square weight matrix $W^O \in \mathbb{R}^{d_{model} \times d_{model}}$, since $d_{model} = h \, d_k$.

Putting it all together, we get:

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O
$$

where $\text{head}_i = \text{Attention}\left(Q W_i^Q,\; K W_i^K,\; V W_i^V\right)$

where again:

$$
W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{model} \times d_k}
$$
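As a rough illustration of the formula, here is a minimal NumPy sketch, assuming single (unbatched) inputs and randomly initialized weights; function names such as `multi_head_attention` are my own choices, not taken from the paper:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    return softmax(scores, axis=-1) @ V

def multi_head_attention(Q, K, V, W_q, W_k, W_v, W_o):
    """MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_o.

    W_q, W_k, W_v are lists of h per-head projections (d_model x d_k);
    W_o is the output projection (d_model x d_model), with d_model = h * d_k.
    """
    heads = [attention(Q @ W_q[i], K @ W_k[i], V @ W_v[i])   # head_i
             for i in range(len(W_q))]
    return np.concatenate(heads, axis=-1) @ W_o

# Toy usage with the vanilla Transformer sizes (self-attention: Q = K = V = X).
d_model, h, seq_len = 512, 8, 10
d_k = d_model // h
rng = np.random.default_rng(0)
X = rng.standard_normal((seq_len, d_model))
W_q = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_k = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_v = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_o = rng.standard_normal((d_model, d_model))

print(multi_head_attention(X, X, X, W_q, W_k, W_v, W_o).shape)   # (10, 512)
```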

Since heads are independent of each other, we can perform the self-attention computation in parallel on different workers.
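One common way to exploit this independence is to stack the per-head weights and compute every head in a single batched operation, which a framework can then parallelize. Below is a sketch of that idea, assuming the same toy sizes as above; the `einsum` layout and the helper name `multi_head_attention_batched` are illustrative choices, not a reference implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention_batched(X, Wq, Wk, Wv, Wo):
    """All heads at once: Wq, Wk, Wv have shape (h, d_model, d_k)."""
    h, _, d_k = Wq.shape
    # Project into every head's subspace in one einsum each: (h, seq, d_k).
    Q = np.einsum('sd,hdk->hsk', X, Wq)
    K = np.einsum('sd,hdk->hsk', X, Wk)
    V = np.einsum('sd,hdk->hsk', X, Wv)
    # Batched scaled dot-product attention over the head axis.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, seq, seq)
    heads = softmax(scores, axis=-1) @ V               # (h, seq, d_k)
    # Concatenate heads back to (seq, h * d_k) and apply the output projection.
    concat = heads.transpose(1, 0, 2).reshape(X.shape[0], h * d_k)
    return concat @ Wo

d_model, h, seq_len = 512, 8, 10
d_k = d_model // h
rng = np.random.default_rng(0)
X = rng.standard_normal((seq_len, d_model))
Wq, Wk, Wv = (rng.standard_normal((h, d_model, d_k)) for _ in range(3))
Wo = rng.standard_normal((d_model, d_model))

print(multi_head_attention_batched(X, Wq, Wk, Wv, Wo).shape)   # (10, 512)
```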
