10.5. Çoklu-Kafalı Dikkat¶ Open the notebook in SageMaker Studio Lab
Pratikte, aynı sorgular, anahtarlar ve değerler kümesi verildiğinde, modelimizin, çeşitli aralıkların bir sıra içinde (örneğin, daha kısa menzile karşı daha uzun menzil) bağımlılıklarını yakalama gibi aynı dikkat mekanizmasının farklı davranışlarından elde edilen bilgileri birleştirmesini isteyebiliriz. Bu nedenle, dikkat mekanizmamızın sorguların, anahtarların ve değerlerin farklı temsil alt alanlarını ortaklaşa kullanmasına izin vermek yararlı olabilir.
Bu amaçla, tek bir dikkat ortaklaması yerine, sorgular, anahtarlar ve değerler \(h\) tane bağımsız olarak öğrenilen doğrusal izdüşümler ile dönüştürülebilir. Daha sonra bu \(h\) öngörülen sorgular, anahtarlar ve değerler paralel olarak dikkat ortaklaması içine beslenir. Nihayetinde, \(h\) dikkat ortaklama çıktıları bitiştirilir ve son çıktıyı üretmek için başka bir öğrenilmiş doğrusal izdüşüm ile dönüştürülür. Bu tasarıma çoklu kafalı dikkat denir, burada \(h\) dikkat ortaklama çıktılarının her biri kafadır (Vaswani et al., 2017). Öğrenilebilir doğrusal dönüşümler gerçekleştirmek için tam bağlı katmanları kullanan çoklu kafalı dikkat Fig. 10.5.1 şeklinde açıklanmıştır.
Fig. 10.5.1 Çoklu kafanın bir araya getirildiği ve ardından doğrusal olarak dönüştürüldüğü çoklu kafalı dikkat.¶
10.5.1. Model¶
Çoklu kafalı dikkatin uygulanmasını sağlamadan önce, bu modeli matematiksel olarak biçimlendirelim. Bir sorgu \(\mathbf{q} \in \mathbb{R}^{d_q}\), bir anahtar \(\mathbf{k} \in \mathbb{R}^{d_k}\) ve bir değer \(\mathbf{v} \in \mathbb{R}^{d_v}\) göz önüne alındığında, her dikkat kafası \(\mathbf{h}_i\) (\(i = 1, \ldots, h\)) aşağıdaki gibi hesaplanır
burada öğrenilebilir parametreler \(\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}\), \(\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}\) ve \(\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}\) ve \(f\), Section 10.3 içindeki toplayıcı dikkat ve ölçeklendirilmiş nokta çarpımı dikkat gibi dikkat ortaklamasıdır. Çoklu kafalı dikkat çıktısı, \(h\) kafalarının bitiştirilmesinin \(\mathbf W_o\in\mathbb R^{p_o\times h p_v}\) öğrenilebilir parametreleri vasıtasıyla başka bir doğrusal dönüşümdür:
Bu tasarıma dayanarak, her kafa girdisinin farklı bölümleriyle ilgilenebilir. Basit ağırlıklı ortalamadan daha gelişmiş fonksiyonlar ifade edilebilir.
import math
from d2l import mxnet as d2l
from mxnet import autograd, np, npx
from mxnet.gluon import nn
npx.set_np()
import math
import torch
from torch import nn
from d2l import torch as d2l
import tensorflow as tf
from d2l import tensorflow as d2l
10.5.2. Uygulama¶
Uygulamamızda, çoklu kafalı dikkatin her bir kafa için ölçeklendirilmiş
nokta-çarpımı dikkatini seçiyoruz. Hesaplama maliyetinde ve
parametreleştirme maliyetinde önemli bir artıştan kaçınmak için
\(p_q = p_k = p_v = p_o / h\) olarak ayarladık. Sorgu, anahtar ve
değer için doğrusal dönüşümlerin çıktı sayısını
\(p_q h = p_k h = p_v h = p_o\) olarak ayarlarsak, \(h\) adet
kafanın paralel olarak hesaplanabileceğini unutmayın. Aşağıdaki
uygulamada, \(p_o\), num_hiddens
bağımsız değişkeni aracılığıyla
belirtilir.
#@save
class MultiHeadAttention(nn.Block):
"""Multi-head attention."""
def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
**kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
def forward(self, queries, keys, values, valid_lens):
# `queries`, `keys`, veya `values` şekli:
# (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`)
# `valid_lens`'in şekli:
# (`batch_size`,) or (`batch_size`, no. of queries)
# Devirme sonrası, output `queries`, `keys`, veya `values` şekli:
# (`batch_size` * `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# 0 ekseninde, ilk öğeyi (skaler veya vektör) `num_heads` kez
# kopyalayın, ardından sonraki öğeyi kopyalayın ve devam edin.
valid_lens = valid_lens.repeat(self.num_heads, axis=0)
# `output`'un şekli: (`batch_size` * `num_heads`, no. of queries,
# `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens)
# `output_concat`'in şekli:
# (`batch_size`, sorgu sayısı, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
#@save
class MultiHeadAttention(nn.Module):
"""Multi-head attention."""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# `queries`, `keys`, veya `values` şekli:
# (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`)
# `valid_lens`'in şekli:
# (`batch_size`,) or (`batch_size`, no. of queries)
# Devirme sonrası, output `queries`, `keys`, veya `values` şekli:
# (`batch_size` * `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# 0 ekseninde, ilk öğeyi (skaler veya vektör) `num_heads` kez
# kopyalayın, ardından sonraki öğeyi kopyalayın ve devam edin.
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
# `output`'un şekli: (`batch_size` * `num_heads`, no. of queries,
# `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens)
# `output_concat`'in şekli:
# (`batch_size`, sorgu sayısı, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
#@save
class MultiHeadAttention(tf.keras.layers.Layer):
"""Multi-head attention."""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
def call(self, queries, keys, values, valid_lens, **kwargs):
# `queries`, `keys`, veya `values` şekli:
# (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`)
# `valid_lens`'in şekli:
# (`batch_size`,) or (`batch_size`, no. of queries)
# Devirme sonrası, output `queries`, `keys`, veya `values` şekli:
# (`batch_size` * `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# 0 ekseninde, ilk öğeyi (skaler veya vektör) `num_heads` kez
# kopyalayın, ardından sonraki öğeyi kopyalayın ve devam edin.
valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0)
# `output`'un şekli: (`batch_size` * `num_heads`, no. of queries,
# `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens, **kwargs)
# `output_concat`'in şekli: (`batch_size`, sorgu sayısı, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
Çoklu kafanın paralel hesaplanmasına izin vermek için, yukarıdaki
MultiHeadAttention
sınıfı aşağıda tanımlandığı gibi iki devrinim
işlevi kullanır. Özellikle, transpose_output
işlevi
transpose_qkv
işlevinin çalışmasını tersine çevirir.
#@save
def transpose_qkv(X, num_heads):
"""Çoklu dikkat kafasının paralel hesaplaması için aktarım."""
# `X` girdisinin şekli:
# (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`).
# `X` çıktısının şekli:
# (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_heads`,
# `num_hiddens` / `num_heads`)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# `X` çıktısının şekli:
# (`batch_size`, `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
# `num_hiddens` / `num_heads`)
X = X.transpose(0, 2, 1, 3)
# `output`'un şekli:
# (`batch_size` * `num_heads`, anahtar-değer çiftleri veya sorgu sayısı
# `num_hiddens` / `num_heads`)
return X.reshape(-1, X.shape[2], X.shape[3])
#@save
def transpose_output(X, num_heads):
"""`transpose_qkv` işlemini tersine çevir."""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.transpose(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
#@save
def transpose_qkv(X, num_heads):
"""Çoklu dikkat kafasının paralel hesaplaması için aktarım."""
# `X` girdisinin şekli:
# (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`).
# `X` çıktısının şekli:
# (`batch_size`,anahtar-değer çiftleri veya sorgu sayısı, `num_heads`,
# `num_hiddens` / `num_heads`)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# `X` çıktısının şekli:
# (`batch_size`, `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
# `num_hiddens` / `num_heads`)
X = X.permute(0, 2, 1, 3)
# `output`'un şekli:
# (`batch_size` * `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
# `num_hiddens` / `num_heads`)
return X.reshape(-1, X.shape[2], X.shape[3])
#@save
def transpose_output(X, num_heads):
"""`transpose_qkv` işlemini tersine çevir."""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
#@save
def transpose_qkv(X, num_heads):
"""Çoklu dikkat kafasının paralel hesaplaması için aktarım."""
# `X` girdisinin şekli:
# (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_hiddens`).
# `X` çıktısının şekli:
# (`batch_size`, anahtar-değer çiftleri veya sorgu sayısı, `num_heads`,
# `num_hiddens` / `num_heads`)
X = tf.reshape(X, shape=(X.shape[0], X.shape[1], num_heads, -1))
# `X` çıktısının şekli:
# (`batch_size`, `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
# `num_hiddens` / `num_heads`)
X = tf.transpose(X, perm=(0, 2, 1, 3))
# `output`'un şekli:
# (`batch_size` * `num_heads`, anahtar-değer çiftleri veya sorgu sayısı,
# `num_hiddens` / `num_heads`)
return tf.reshape(X, shape=(-1, X.shape[2], X.shape[3]))
#@save
def transpose_output(X, num_heads):
"""`transpose_qkv` işlemini tersine çevir."""
X = tf.reshape(X, shape=(-1, num_heads, X.shape[1], X.shape[2]))
X = tf.transpose(X, perm=(0, 2, 1, 3))
return tf.reshape(X, shape=(X.shape[0], X.shape[1], -1))
Anahtarların ve değerlerin aynı olduğu bir basit örneği kullanarak
uygulanan MultiHeadAttention
sınıfını test edelim. Sonuç olarak,
çoklu kafalı dikkat çıktısının şekli (batch_size
, num_queries
,
num_hiddens
) şeklindedir.
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
Y = np.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
(2, 4, 100)
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
Y = tf.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens, training=False).shape
TensorShape([2, 4, 100])
10.5.3. Özet¶
Çoklu kafalı dikkat, sorguların, anahtarların ve değerlerin farklı temsil altuzayları aracılığıyla aynı dikkat ortaklama bilgisini birleştirir.
Çoklu kafalı dikkatin çoklu kafasını paralel olarak hesaplamak için uygun tensör düzenlemeleri gereklidir.
10.5.4. Alıştırmalar¶
Bu deneydeki çoklu kafanın dikkat ağırlıklarını görselleştirin.
Çoklu kafa dikkatine dayalı eğitilmiş bir modelimiz olduğunu ve tahmin hızını artırmak için en az önemli dikkat kafalarını budamak istediğimizi varsayalım. Bir dikkat kafasının önemini ölçmek için deneyleri nasıl tasarlayabiliriz.