10.5. Çoklu-Kafalı Dikkat
Open the notebook in Colab
Open the notebook in Colab
Open the notebook in Colab
Open the notebook in SageMaker Studio Lab

Pratikte, aynı sorgular, anahtarlar ve değerler kümesi verildiğinde, modelimizin, çeşitli aralıkların bir sıra içinde (örneğin, daha kısa menzile karşı daha uzun menzil) bağımlılıklarını yakalama gibi aynı dikkat mekanizmasının farklı davranışlarından elde edilen bilgileri birleştirmesini isteyebiliriz. Bu nedenle, dikkat mekanizmamızın sorguların, anahtarların ve değerlerin farklı temsil alt alanlarını ortaklaşa kullanmasına izin vermek yararlı olabilir.

Bu amaçla, tek bir dikkat ortaklaması yerine, sorgular, anahtarlar ve değerler \(h\) tane bağımsız olarak öğrenilen doğrusal izdüşümler ile dönüştürülebilir. Daha sonra bu \(h\) öngörülen sorgular, anahtarlar ve değerler paralel olarak dikkat ortaklaması içine beslenir. Nihayetinde, \(h\) dikkat ortaklama çıktıları bitiştirilir ve son çıktıyı üretmek için başka bir öğrenilmiş doğrusal izdüşüm ile dönüştürülür. Bu tasarıma çoklu kafalı dikkat denir, burada \(h\) dikkat ortaklama çıktılarının her biri kafadır (Vaswani et al., 2017). Öğrenilebilir doğrusal dönüşümler gerçekleştirmek için tam bağlı katmanları kullanan çoklu kafalı dikkat Fig. 10.5.1 şeklinde açıklanmıştır.

../_images/multi-head-attention.svg

Fig. 10.5.1 Çoklu kafanın bir araya getirildiği ve ardından doğrusal olarak dönüştürüldüğü çoklu kafalı dikkat.

10.5.1. Model

Çoklu kafalı dikkatin uygulanmasını sağlamadan önce, bu modeli matematiksel olarak biçimlendirelim. Bir sorgu \(\mathbf{q} \in \mathbb{R}^{d_q}\), bir anahtar \(\mathbf{k} \in \mathbb{R}^{d_k}\) ve bir değer \(\mathbf{v} \in \mathbb{R}^{d_v}\) göz önüne alındığında, her dikkat kafası \(\mathbf{h}_i\) (\(i = 1, \ldots, h\)) aşağıdaki gibi hesaplanır

(10.5.1)\[\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},\]

burada öğrenilebilir parametreler \(\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}\), \(\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}\) ve \(\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}\) ve \(f\), Section 10.3 içindeki toplayıcı dikkat ve ölçeklendirilmiş nokta çarpımı dikkat gibi dikkat ortaklamasıdır. Çoklu kafalı dikkat çıktısı, \(h\) kafalarının bitiştirilmesinin \(\mathbf W_o\in\mathbb R^{p_o\times h p_v}\) öğrenilebilir parametreleri vasıtasıyla başka bir doğrusal dönüşümdür:

(10.5.2)\[\begin{split}\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.\end{split}\]

Bu tasarıma dayanarak, her kafa girdisinin farklı bölümleriyle ilgilenebilir. Basit ağırlıklı ortalamadan daha gelişmiş fonksiyonlar ifade edilebilir.

import math
from d2l import mxnet as d2l
from mxnet import autograd, np, npx
from mxnet.gluon import nn

npx.set_np()
import math
import torch
from torch import nn
from d2l import torch as d2l
import tensorflow as tf
from d2l import tensorflow as d2l

10.5.2. Uygulama

Uygulamamızda, çoklu kafalı dikkatin her bir kafa için ölçeklendirilmiş nokta-çarpımı dikkatini seçiyoruz. Hesaplama maliyetinde ve parametreleştirme maliyetinde önemli bir artıştan kaçınmak için \(p_q = p_k = p_v = p_o / h\) olarak ayarladık. Sorgu, anahtar ve değer için doğrusal dönüşümlerin çıktı sayısını \(p_q h = p_k h = p_v h = p_o\) olarak ayarlarsak, \(h\) adet kafanın paralel olarak hesaplanabileceğini unutmayın. Aşağıdaki uygulamada, \(p_o\), num_hiddens bağımsız değişkeni aracılığıyla belirtilir.

#@save
class MultiHeadAttention(nn.Block):
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
                 **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)

    def forward(self, queries, keys, values, valid_lens):
        # `queries`, `keys`, veya `values` şekli:
        # (`batch_size`,  anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`)
        # `valid_lens`'in şekli:
        # (`batch_size`,) or (`batch_size`, no. of queries)
        # Devirme sonrası, output `queries`, `keys`, veya `values` şekli:
        # (`batch_size` * `num_heads`,  anahtar-değer çiftleri veya sorgu sayısı,
        # `num_hiddens` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 0 ekseninde, ilk öğeyi (skaler veya vektör) `num_heads` kez
            # kopyalayın, ardından sonraki öğeyi kopyalayın ve devam edin.
            valid_lens = valid_lens.repeat(self.num_heads, axis=0)

        # `output`'un şekli: (`batch_size` * `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)

        # `output_concat`'in şekli:
        # (`batch_size`, sorgu sayısı, `num_hiddens`)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
#@save
class MultiHeadAttention(nn.Module):
    """Multi-head attention."""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # `queries`, `keys`, veya `values` şekli:
        # (`batch_size`,  anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`)
        # `valid_lens`'in şekli:
        # (`batch_size`,) or (`batch_size`, no. of queries)
        # Devirme sonrası, output `queries`, `keys`, veya `values` şekli:
        # (`batch_size` * `num_heads`,  anahtar-değer çiftleri veya sorgu sayısı,
        # `num_hiddens` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 0 ekseninde, ilk öğeyi (skaler veya vektör) `num_heads` kez
            # kopyalayın, ardından sonraki öğeyi kopyalayın ve devam edin.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # `output`'un şekli: (`batch_size` * `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)

        # `output_concat`'in şekli:
        # (`batch_size`, sorgu sayısı, `num_hiddens`)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
#@save
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention."""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias)

    def call(self, queries, keys, values, valid_lens, **kwargs):
        # `queries`, `keys`, veya `values` şekli:
        # (`batch_size`,  anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`)
        # `valid_lens`'in şekli:
        # (`batch_size`,) or (`batch_size`, no. of queries)
        # Devirme sonrası, output `queries`, `keys`, veya `values` şekli:
        # (`batch_size` * `num_heads`,  anahtar-değer çiftleri veya sorgu sayısı,
        # `num_hiddens` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 0 ekseninde, ilk öğeyi (skaler veya vektör) `num_heads` kez
            # kopyalayın, ardından sonraki öğeyi kopyalayın ve devam edin.
            valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0)

        # `output`'un şekli: (`batch_size` * `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens, **kwargs)

        # `output_concat`'in şekli: (`batch_size`, sorgu sayısı, `num_hiddens`)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

Çoklu kafanın paralel hesaplanmasına izin vermek için, yukarıdaki MultiHeadAttention sınıfı aşağıda tanımlandığı gibi iki devrinim işlevi kullanır. Özellikle, transpose_output işlevi transpose_qkv işlevinin çalışmasını tersine çevirir.

#@save
def transpose_qkv(X, num_heads):
    """Çoklu dikkat kafasının paralel hesaplaması için aktarım."""
    # `X` girdisinin şekli:
    # (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`).
    # `X` çıktısının şekli:
    # (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_heads`,
    # `num_hiddens` / `num_heads`)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # `X` çıktısının şekli:
    # (`batch_size`, `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
    # `num_hiddens` / `num_heads`)
    X = X.transpose(0, 2, 1, 3)

    # `output`'un şekli:
    # (`batch_size` * `num_heads`, anahtar-değer çiftleri veya sorgu sayısı
    # `num_hiddens` / `num_heads`)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """`transpose_qkv` işlemini tersine çevir."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
#@save
def transpose_qkv(X, num_heads):
    """Çoklu dikkat kafasının paralel hesaplaması için aktarım."""
    # `X` girdisinin şekli:
    # (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`).
    # `X` çıktısının şekli:
    # (`batch_size`,anahtar-değer çiftleri veya sorgu sayısı, `num_heads`,
    # `num_hiddens` / `num_heads`)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # `X` çıktısının şekli:
    # (`batch_size`, `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
    # `num_hiddens` / `num_heads`)
    X = X.permute(0, 2, 1, 3)

    # `output`'un şekli:
    # (`batch_size` * `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
    # `num_hiddens` / `num_heads`)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """`transpose_qkv` işlemini tersine çevir."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
#@save
def transpose_qkv(X, num_heads):
    """Çoklu dikkat kafasının paralel hesaplaması için aktarım."""
    # `X` girdisinin şekli:
    # (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`).
    # `X` çıktısının şekli:
    # (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_heads`,
    # `num_hiddens` / `num_heads`)
    X = tf.reshape(X, shape=(X.shape[0], X.shape[1], num_heads, -1))

    # `X` çıktısının şekli:
    # (`batch_size`, `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
    # `num_hiddens` / `num_heads`)
    X = tf.transpose(X, perm=(0, 2, 1, 3))

    # `output`'un şekli:
    # (`batch_size` * `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
    # `num_hiddens` / `num_heads`)
    return tf.reshape(X, shape=(-1, X.shape[2], X.shape[3]))


#@save
def transpose_output(X, num_heads):
    """`transpose_qkv` işlemini tersine çevir."""
    X = tf.reshape(X, shape=(-1, num_heads, X.shape[1], X.shape[2]))
    X = tf.transpose(X, perm=(0, 2, 1, 3))
    return tf.reshape(X, shape=(X.shape[0], X.shape[1], -1))

Anahtarların ve değerlerin aynı olduğu bir basit örneği kullanarak uygulanan MultiHeadAttention sınıfını test edelim. Sonuç olarak, çoklu kafalı dikkat çıktısının şekli (batch_size, num_queries, num_hiddens) şeklindedir.

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()

batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
Y = np.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
(2, 4, 100)
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)

batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
Y = tf.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens, training=False).shape
TensorShape([2, 4, 100])

10.5.3. Özet

  • Çoklu kafalı dikkat, sorguların, anahtarların ve değerlerin farklı temsil altuzayları aracılığıyla aynı dikkat ortaklama bilgisini birleştirir.

  • Çoklu kafalı dikkatin çoklu kafasını paralel olarak hesaplamak için uygun tensör düzenlemeleri gereklidir.

10.5.4. Alıştırmalar

  1. Bu deneydeki çoklu kafanın dikkat ağırlıklarını görselleştirin.

  2. Çoklu kafa dikkatine dayalı eğitilmiş bir modelimiz olduğunu ve tahmin hızını artırmak için en az önemli dikkat kafalarını budamak istediğimizi varsayalım. Bir dikkat kafasının önemini ölçmek için deneyleri nasıl tasarlayabiliriz.