Sinir Stil Transferi
====================
Eğer bir fotoğraf meraklısı iseniz, filtreye aşina olabilirsiniz.
Fotoğrafların renk stilini değiştirebilir, böylece manzara fotoğrafları
daha keskin hale gelir veya portre fotoğrafları beyaz tonlara sahip
olur. Ancak, bir filtre genellikle fotoğrafın yalnızca bir yönünü
değiştirir. Bir fotoğrafa ideal bir stil uygulamak için muhtemelen
birçok farklı filtre kombinasyonunu denemeniz gerekir. Bu işlem, bir
modelin hiper parametrelerini ayarlamak kadar karmaşıktır.
Bu bölümde, bir imgenin stilini otomatik olarak başka bir imgeye (örn.
*stil aktarımı* :cite:`Gatys.Ecker.Bethge.2016`) uygulamak için
CNN'nin katmanlı temsillerinden yararlanacağız. Bu görevde iki girdi
imgesi gerekir: Biri *içerik imgesi*, diğeri ise *stil imgesi*. İçerik
imgesini stil imgesine yakın hale getirmek için sinir ağlarını
kullanacağız. Örneğin, :numref:`fig_style_transfer` içindeki içerik
imgesi Seattle'ın banliyölerindeki Rainier Milli Parkı'nda tarafımızdan
çekilen bir manzara fotoğrafıdır ve stil imgesi sonbahar meşe ağaçları
temalı bir yağlı boya tablosudur. Sentezlenmiş çıktı imgesinde, stil
imgesinin yağlı fırça darbeleri uygulanarak, içerik imgesindeki
nesnelerin ana şekli korunurken daha canlı renkler elde edilir.
.. _fig_style_transfer:
.. figure:: ../img/style-transfer.svg
Verilen içerik ve stil imgeleri, stil aktarımı sentezlenmiş bir imge
verir.
Yöntem
------
:numref:`fig_style_transfer_model`, CNN tabanlı stil aktarım yöntemini
basitleştirilmiş bir örnekle gösterir. İlk olarak, sentezlenen imgeyi,
örneğin içerik imgesine ilkleriz. Bu sentezlenen imge, stil aktarımı
işlemi sırasında güncellenmesi gereken tek değişkendir, yani eğitim
sırasında güncellenecek model parametreleri. Daha sonra imge
özniteliklerini ayıklamak için önceden eğitilmiş bir CNN seçiyoruz ve
eğitim sırasında model parametrelerini donduruyoruz. Bu derin CNN
imgeler için hiyerarşik öznitelikleri ayıklamak için çoklu katman
kullanır. İçerik öznitelikleri veya stil öznitelikleri olarak bu
katmanlardan bazılarının çıktısını seçebiliriz. Örnek olarak
:numref:`fig_style_transfer_model` figürünü ele alın. Buradaki önceden
eğitilmiş sinir ağı, ikinci katmanın içerik özniteliklerini çıkardığı ve
birinci ve üçüncü katmanlar stil özniteliklerini çıkardığı 3 evrişimli
katmana sahiptir.
.. _fig_style_transfer_model:
.. figure:: ../img/neural-style.svg
CNN tabanlı stil aktarım süreci. Düz çizgiler ileri yayma yönünü ve
noktalı çizgiler geriye yaymayı gösterir.
Daha sonra, ileri yayma yoluyla stil aktarımının kayıp işlevini
hesaplarız (katı okların yönü) ve model parametrelerini (çıktı için
sentezlenmiş imge) geri yayma (kesikli okların yönü) ile güncelleriz.
Stil aktarımında yaygın olarak kullanılan kayıp fonksiyonu üç bölümden
oluşur: (i) *içerik kaybı* sentezlenen imgeyi ve içerik imgesini içerik
özniteliklerinde yakınlaştırır; (ii) *stil kaybı* sentezlenen imge ve
stil imgesini stil özniteliklerinde yakınlaştırır; ve (iii) *toplam
değişim kaybı* sentezlenen imgede gürültü azaltmaya yardım eder. Son
olarak, model eğitimi bittiğinde, son sentezlenmiş imgeyi oluşturmak
için stil aktarımının model parametrelerini çıktı olarak veririz.
Aşağıda, somut bir deney yoluyla stil aktarımının teknik detaylarını
açıklayacağız.
İçerik ve Stil İmgelerini Okuma
-------------------------------
İlk olarak, içerik ve stil imgelerini okuyoruz. Basılı koordinat
eksenlerinden, bu imgelerin farklı boyutlarda olduğunu söyleyebiliriz.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
%matplotlib inline
from d2l import mxnet as d2l
from mxnet import autograd, gluon, image, init, np, npx
from mxnet.gluon import nn
npx.set_np()
d2l.set_figsize()
content_img = image.imread('../img/rainier.jpg')
d2l.plt.imshow(content_img.asnumpy());
.. figure:: output_neural-style_5de8ca_3_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
style_img = image.imread('../img/autumn-oak.jpg')
d2l.plt.imshow(style_img.asnumpy());
.. figure:: output_neural-style_5de8ca_4_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
d2l.set_figsize()
content_img = d2l.Image.open('../img/rainier.jpg')
d2l.plt.imshow(content_img);
.. figure:: output_neural-style_5de8ca_7_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
style_img = d2l.Image.open('../img/autumn-oak.jpg')
d2l.plt.imshow(style_img);
.. figure:: output_neural-style_5de8ca_8_0.svg
.. raw:: html
.. raw:: html
Ön İşleme ve Sonradan İşleme
----------------------------
Aşağıda, ön işleme ve sonradan işleme imgeleri için iki işlev
tanımlıyoruz. ``preprocess`` işlevi, girdi imgesinin üç RGB kanalının
her birini standartlaştırır ve sonuçları CNN girdi biçimine dönüştürür.
``postprocess`` işlevi, standartlaştırmadan önce çıktı imgesinde piksel
değerlerini orijinal değerlerine geri yükler. İmge yazdırma işlevi, her
pikselin 0'dan 1'e kadar kayan virgüllü sayı değerine sahip olmasını
gerektirdiğinden, 0'dan küçük veya 1'den büyük herhangi bir değeri
sırasıyla 0 veya 1 ile değiştiririz.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])
def preprocess(img, image_shape):
img = image.imresize(img, *image_shape)
img = (img.astype('float32') / 255 - rgb_mean) / rgb_std
return np.expand_dims(img.transpose(2, 0, 1), axis=0)
def postprocess(img):
img = img[0].as_in_ctx(rgb_std.ctx)
return (img.transpose(1, 2, 0) * rgb_std + rgb_mean).clip(0, 1)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])
def preprocess(img, image_shape):
transforms = torchvision.transforms.Compose([
torchvision.transforms.Resize(image_shape),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
return transforms(img).unsqueeze(0)
def postprocess(img):
img = img[0].to(rgb_std.device)
img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))
.. raw:: html
.. raw:: html
Öznitelik Ayıklama
------------------
İmge özniteliklerini :cite:`Gatys.Ecker.Bethge.2016` ayıklamak için
ImageNet veri kümesinde önceden eğitilmiş VGG-19 modelini kullanıyoruz.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
pretrained_net = gluon.model_zoo.vision.vgg19(pretrained=True)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
pretrained_net = torchvision.models.vgg19(pretrained=True)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
/home/d2l-worker/miniconda3/envs/d2l-tr-release-0/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
warnings.warn(
/home/d2l-worker/miniconda3/envs/d2l-tr-release-0/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
.. raw:: html
.. raw:: html
İmgenin içerik özniteliklerini ve stil özniteliklerini ayıklamak için,
VGG ağındaki belirli katmanların çıktısını seçebiliriz. Genel olarak,
girdi katmanına ne kadar yakın olursa, imgenin ayrıntılarını çıkarmak
daha kolay olur ve tersi yönde de, imgenin küresel bilgilerini çıkarmak
daha kolay olur. Sentezlenen imgede içerik imgesinin ayrıntılarını aşırı
derecede tutmaktan kaçınmak için imgenin içerik özniteliklerinin
çıktısını almak için *içerik katmanı* olarak çıktıya daha yakın bir VGG
katmanı seçiyoruz. Yerel ve küresel stil özniteliklerini ayıklamak için
farklı VGG katmanlarının çıktısını da seçiyoruz. Bu katmanlar *stil
katmanları* olarak da adlandırılır. :numref:`sec_vgg` içinde
belirtildiği gibi, VGG ağı 5 evrişimli blok kullanır. Deneyde, dördüncü
evrişimli bloğun son evrişimli katmanını içerik katmanı olarak ve her
bir evrişimli bloğun ilk evrişimli katmanını stil katmanı olarak
seçiyoruz. Bu katmanların indeksleri ``pretrained_net`` örneğini
yazdırarak elde edilebilir.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
.. raw:: html
.. raw:: html
VGG katmanlarını kullanarak öznitelikleri ayıklarken, yalnızca girdi
katmanından içerik katmanına veya çıktı katmanına en yakın stil
katmanına kadar tüm bunları kullanmamız gerekir. Yalnızca öznitelik
ayıklama için kullanılacak tüm VGG katmanlarını koruyan yeni bir ağ
örneği, ``net``, oluşturalım.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential()
for i in range(max(content_layers + style_layers) + 1):
net.add(pretrained_net.features[i])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential(*[pretrained_net.features[i] for i in
range(max(content_layers + style_layers) + 1)])
.. raw:: html
.. raw:: html
``X`` girdisi göz önüne alındığında, sadece ileri yayma ``net(X)``'i
çağırırsak, yalnızca son katmanın çıktısını alabiliriz. Ara katmanların
çıktılarına da ihtiyacımız olduğundan, katman bazında hesaplama yapmalı,
içerik ve stil katmanı çıktılarını korumalıyız.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def extract_features(X, content_layers, style_layers):
contents = []
styles = []
for i in range(len(net)):
X = net[i](X)
if i in style_layers:
styles.append(X)
if i in content_layers:
contents.append(X)
return contents, styles
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def extract_features(X, content_layers, style_layers):
contents = []
styles = []
for i in range(len(net)):
X = net[i](X)
if i in style_layers:
styles.append(X)
if i in content_layers:
contents.append(X)
return contents, styles
.. raw:: html
.. raw:: html
Aşağıda iki işlev tanımlanmıştır: ``get_contents`` işlevi içerik
imgesinden içerik özniteliklerini ayıklar ve ``get_styles`` işlevi stil
imgesinden stil özniteliklerini ayıklar. Eğitim sırasında önceden
eğitilmiş VGG'nin model parametrelerini güncellemeye gerek olmadığından,
eğitim başlamadan bile içerik ve stil özniteliklerini ayıklayabiliriz.
Sentezlenen imge, stil aktarımı için güncellenecek bir model
parametreleri kümesi olduğundan, yalnızca sentezlenen imgenin içerik ve
stil özniteliklerini eğitim sırasında ``extract_features`` işlevini
çağırarak ayıklayabiliriz.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def get_contents(image_shape, device):
content_X = preprocess(content_img, image_shape).copyto(device)
contents_Y, _ = extract_features(content_X, content_layers, style_layers)
return content_X, contents_Y
def get_styles(image_shape, device):
style_X = preprocess(style_img, image_shape).copyto(device)
_, styles_Y = extract_features(style_X, content_layers, style_layers)
return style_X, styles_Y
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def get_contents(image_shape, device):
content_X = preprocess(content_img, image_shape).to(device)
contents_Y, _ = extract_features(content_X, content_layers, style_layers)
return content_X, contents_Y
def get_styles(image_shape, device):
style_X = preprocess(style_img, image_shape).to(device)
_, styles_Y = extract_features(style_X, content_layers, style_layers)
return style_X, styles_Y
.. raw:: html
.. raw:: html
Kayıp Fonksiyonunu Tanımlama
----------------------------
Şimdi stil aktarımı için kayıp işlevini açıklayacağız. Kayıp fonksiyonu
içerik kaybı, stil kaybı ve toplam değişim kaybından oluşur.
İçerik Kaybı
~~~~~~~~~~~~
Doğrusal bağlanımdaki kayıp işlevine benzer şekilde, içerik kaybı, kare
kayıp fonksiyonu aracılığıyla, sentezlenen imge ile içerik imgesi
arasındaki içerik özniteliklerindeki farkı ölçer. Kare kayıp
fonksiyonunun iki girdisi, ``extract_features`` fonksiyonu tarafından
hesaplanan içerik katmanının çıktılarıdır.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def content_loss(Y_hat, Y):
return np.square(Y_hat - Y).mean()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def content_loss(Y_hat, Y):
# We detach the target content from the tree used to dynamically compute
# the gradient: this is a stated value, not a variable. Otherwise the loss
# will throw an error.
return torch.square(Y_hat - Y.detach()).mean()
.. raw:: html
.. raw:: html
Stil Kaybı
~~~~~~~~~~
İçerik kaybına benzer şekilde stil kaybı, aynı zamanda sentezlenen imge
ile stil imgesi arasındaki stil farkını ölçmek için kare kaybı işlevini
kullanır. Herhangi bir stil katmanının stil çıktısını ifade etmeden önce
stil katmanı çıktısını hesaplamak için ``extract_features`` işlevini
kullanırız. Çıktının 1 örneği, :math:`c` kanalları, :math:`h` yüksekliği
ve :math:`w` genişliği olduğunu varsayalım, bu çıktıyı :math:`c`
satırları ve :math:`hw` sütunlarıyla :math:`\mathbf{X}` matrisine
dönüştürebiliriz. Bu matris, her birinin uzunluğu :math:`hw` olan
:math:`c` adet :math:`\mathbf{x}_1, \ldots, \mathbf{x}_c` vektörlerinin
birleşimi olarak düşünülebilir. Burada :math:`\mathbf{x}_i` vektörü,
:math:`i` kanalının stil özniteliğini temsil eder.
Bu vektörlerin *Gram matrisinde*
:math:`\mathbf{X}\mathbf{X}^\top \in \mathbb{R}^{c \times c}`, :math:`i`
satırındaki ve :math:`j` sütunundaki :math:`x_{ij}` öğesi
:math:`\mathbf{x}_j` vektörlerinin nokta çarpımıdır. :math:`i` ve
:math:`j` kanallarının stil özniteliklerinin korelasyonunu temsil eder.
Bu Gram matrisini herhangi bir stil katmanının stil çıktısını temsil
etmek için kullanırız. :math:`hw` değeri daha büyük olduğunda, büyük
olasılıkla Gram matrisinde daha büyük değerlere yol açtığını unutmayın.
Gram matrisinin yüksekliği ve genişliğinin ikisinin de kanal sayısının
:math:`c` olduğunu da unutmayın. Stil kaybının bu değerlerden
etkilenmemesine izin vermek için, aşağıdaki ``gram`` işlevi Gram
matrisini elemanlarının sayısına (yani :math:`chw`) böler.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def gram(X):
num_channels, n = X.shape[1], d2l.size(X) // X.shape[1]
X = X.reshape((num_channels, n))
return np.dot(X, X.T) / (num_channels * n)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def gram(X):
num_channels, n = X.shape[1], X.numel() // X.shape[1]
X = X.reshape((num_channels, n))
return torch.matmul(X, X.T) / (num_channels * n)
.. raw:: html
.. raw:: html
Açıkçası, stil kaybı için kare kayıp fonksiyonunun iki Gram matris
girdisi, sentezlenen imge ve stil imgesi için stil katmanı çıktılarına
dayanır. Burada stil imgesine dayanan Gram matrisi ``gram_Y``'nin
önceden hesaplandığı varsayılmaktadır.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def style_loss(Y_hat, gram_Y):
return np.square(gram(Y_hat) - gram_Y).mean()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def style_loss(Y_hat, gram_Y):
return torch.square(gram(Y_hat) - gram_Y.detach()).mean()
.. raw:: html
.. raw:: html
Toplam Değişim Kaybı
~~~~~~~~~~~~~~~~~~~~
Bazen, öğrenilen sentezlenen imge çok yüksek frekanslı gürültüye, yani
özellikle parlak veya koyu piksellere sahiptir. Bir yaygın gürültü
azaltma yöntemi *toplam değişim gürültü arındırma*\ dır. :math:`(i, j)`
koordinatında piksel değerini :math:`x_{i, j}` ile belirtin. Toplam
değişim kaybının azaltma
.. math:: \sum_{i, j} \left|x_{i, j} - x_{i+1, j}\right| + \left|x_{i, j} - x_{i, j+1}\right|
, sentezlenen imgedeki komşu piksellerin değerlerini yakınlaştırır.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def tv_loss(Y_hat):
return 0.5 * (np.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
np.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def tv_loss(Y_hat):
return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
.. raw:: html
.. raw:: html
Kayıp Fonksiyonu
~~~~~~~~~~~~~~~~
Stil aktarımının kayıp fonksiyonu, içerik kaybı, stil kaybı ve toplam
değişim kaybının ağırlıklı toplamıdır. Bu ağırlık hiper parametrelerini
ayarlayarak sentezlenen imgede, içerik tutma, stil aktarma ve gürültü
azaltma arasında denge kurabiliriz.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
content_weight, style_weight, tv_weight = 1, 1e3, 10
def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
# Sırasıyla içerik, stil ve toplam varyans kayıplarını hesaplayın
contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
contents_Y_hat, contents_Y)]
styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
styles_Y_hat, styles_Y_gram)]
tv_l = tv_loss(X) * tv_weight
# Bütün kayıpları toplayın
l = sum(10 * styles_l + contents_l + [tv_l])
return contents_l, styles_l, tv_l, l
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
content_weight, style_weight, tv_weight = 1, 1e3, 10
def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
# Sırasıyla içerik, stil ve toplam varyans kayıplarını hesaplayın
contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
contents_Y_hat, contents_Y)]
styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
styles_Y_hat, styles_Y_gram)]
tv_l = tv_loss(X) * tv_weight
# Bütün kayıpları toplayın
l = sum(10 * styles_l + contents_l + [tv_l])
return contents_l, styles_l, tv_l, l
.. raw:: html
.. raw:: html
Sentezlenmiş İmgeyi İlkleme
---------------------------
Stil transferinde, sentezlenen imge, eğitim sırasında güncellenmesi
gereken tek değişkendir. Böylece, basit bir model, ``SynthesizedImage``
tanımlayabilir ve sentezlenen imgeyi model parametreleri olarak ele
alabiliriz. Bu modelde, ileri yayma sadece model parametrelerini
döndürür.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class SynthesizedImage(nn.Block):
def __init__(self, img_shape, **kwargs):
super(SynthesizedImage, self).__init__(**kwargs)
self.weight = self.params.get('weight', shape=img_shape)
def forward(self):
return self.weight.data()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class SynthesizedImage(nn.Module):
def __init__(self, img_shape, **kwargs):
super(SynthesizedImage, self).__init__(**kwargs)
self.weight = nn.Parameter(torch.rand(*img_shape))
def forward(self):
return self.weight
.. raw:: html
.. raw:: html
Sonrasında, ``get_inits`` işlevini tanımlıyoruz. Bu işlev, sentezlenmiş
bir imge modeli örneği oluşturur ve ``X`` imgesine ilkler. Çeşitli stil
katmanlarındaki stil imgesi için Gram matrisleri, ``styles_Y_gram``,
eğitimden önce hesaplanır.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def get_inits(X, device, lr, styles_Y):
gen_img = SynthesizedImage(X.shape)
gen_img.initialize(init.Constant(X), ctx=device, force_reinit=True)
trainer = gluon.Trainer(gen_img.collect_params(), 'adam',
{'learning_rate': lr})
styles_Y_gram = [gram(Y) for Y in styles_Y]
return gen_img(), styles_Y_gram, trainer
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def get_inits(X, device, lr, styles_Y):
gen_img = SynthesizedImage(X.shape).to(device)
gen_img.weight.data.copy_(X.data)
trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
styles_Y_gram = [gram(Y) for Y in styles_Y]
return gen_img(), styles_Y_gram, trainer
.. raw:: html
.. raw:: html
Eğitim
------
Modeli stil aktarımı için eğitirken, sentezlenen imgenin içerik
özniteliklerini ve stil özniteliklerini sürekli olarak ayıklarız ve
kayıp işlevini hesaplarız. Aşağıda eğitim döngüsünü tanımlıyoruz.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
animator = d2l.Animator(xlabel='epoch', ylabel='loss',
xlim=[10, num_epochs], ylim=[0, 20],
legend=['content', 'style', 'TV'],
ncols=2, figsize=(7, 2.5))
for epoch in range(num_epochs):
with autograd.record():
contents_Y_hat, styles_Y_hat = extract_features(
X, content_layers, style_layers)
contents_l, styles_l, tv_l, l = compute_loss(
X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
l.backward()
trainer.step(1)
if (epoch + 1) % lr_decay_epoch == 0:
trainer.set_learning_rate(trainer.learning_rate * 0.8)
if (epoch + 1) % 10 == 0:
animator.axes[1].imshow(postprocess(X).asnumpy())
animator.add(epoch + 1, [float(sum(contents_l)),
float(sum(styles_l)), float(tv_l)])
return X
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
animator = d2l.Animator(xlabel='epoch', ylabel='loss',
xlim=[10, num_epochs],
legend=['content', 'style', 'TV'],
ncols=2, figsize=(7, 2.5))
for epoch in range(num_epochs):
trainer.zero_grad()
contents_Y_hat, styles_Y_hat = extract_features(
X, content_layers, style_layers)
contents_l, styles_l, tv_l, l = compute_loss(
X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
l.backward()
trainer.step()
scheduler.step()
if (epoch + 1) % 10 == 0:
animator.axes[1].imshow(postprocess(X))
animator.add(epoch + 1, [float(sum(contents_l)),
float(sum(styles_l)), float(tv_l)])
return X
.. raw:: html
.. raw:: html
Şimdi modeli eğitmeye başlıyoruz. İçerik ve stil imgelerinin
yüksekliğini ve genişliğini 300 x 450 piksele yeniden ölçeklendiriyoruz.
Sentezlenen imgeyi ilklemek için içerik imgesini kullanırız.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
device, image_shape = d2l.try_gpu(), (450, 300)
net.collect_params().reset_ctx(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.9, 500, 50)
.. figure:: output_neural-style_5de8ca_140_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
device, image_shape = d2l.try_gpu(), (300, 450) # PIL Image (h, w)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
.. figure:: output_neural-style_5de8ca_143_0.svg
.. raw:: html
.. raw:: html
Sentezlenen imgenin içerik imgesinin manzarasını ve nesnelerini
koruduğunu ve aynı zamanda stil imgesinin rengini aktardığını
görebiliriz. Örneğin, sentezlenen imgenin stil imgesinde olduğu gibi
renk blokları vardır. Bu blokların bazıları fırça darbelerinin ince
dokusuna bile sahiptir.
Özet
----
- Stil aktarımında yaygın olarak kullanılan kayıp fonksiyonu üç
bölümden oluşur: (i) İçerik kaybı, sentezlenen imgeyi ve içerik
imgesini içerik özniteliklerinde yakınlaştırır; (ii) stil kaybı,
sentezlenen imge ve stil imgesini stil özniteliklerinde
yakınlaştırır; ve (iii) toplam değişim kaybı sentezlenmiş imgedeki
gürültüyü azaltmayı sağlar.
- Eğitim sırasında sentezlenen imgeyi sürekli olarak model
parametreleri olarak güncellemek için imge özniteliklerini ayıklamak
ve kayıp işlevini en aza indirmek için önceden eğitilmiş bir CNN
kullanabiliriz.
- Stil katmanlarından stil çıktılarını temsil etmek için Gram
matrisleri kullanırız.
Alıştırmalar
------------
1. Farklı içerik ve stil katmanları seçtiğinizde çıktı nasıl değişir?
2. Kayıp işlevindeki ağırlık hiper parametrelerini ayarlayın. Çıktıda
daha fazla içerik mi yoksa daha az gürültü mü var?
3. Farklı içerik ve stil imgeleri kullanın. Sentezlenmiş daha ilginç
imgeler oluşturabilir misiniz?
4. Metin için stil aktarımı uygulayabilir miyiz? İpucu: Hu ve
arkadaşlarının araştırma makalesine başvurabilirsiniz
:cite:`Hu.Lee.Aggarwal.ea.2020`.
.. raw:: html
.. raw:: html
`Tartışmalar `__
.. raw:: html
.. raw:: html
`Tartışmalar `__
.. raw:: html
.. raw:: html