Image style transfer is quite easy to understand: the goal is to transfer the "style" of one image (the style image) onto another image (the content image). The synthesized result differs from the content image only in style; its "content" should stay the same.
The works of Luan et al. and Gatys et al. both use VGGNet19 as the backbone for this task. Since VGGNet19 has a roughly "pyramid"-shaped structure, the receptive field of the feature maps grows as the convolutions get deeper, and the extracted features expand from local to global. To keep the synthesized image from retaining too much content information, we pick a convolutional layer near the top of the VGGNet19 pyramid as the content layer. The training procedure is: initialize the generated image to the content image; in each iteration, extract the content features of the generated image and the content image and minimize their MSE, and at the same time extract the style features of the generated image and the style image and minimize their MSE. Note how the loss function is written:
$$
L_{\text{total}} \;=\; \alpha \sum_{l \in C} \mathrm{MSE}\left(F_{g}^{l},\, F_{c}^{l}\right) \;+\; \Gamma\,\beta \sum_{l \in S} \mathrm{MSE}\left(G(F_{g}^{l}),\, G(F_{s}^{l})\right)
$$

where $F_{g}^{l}$, $F_{c}^{l}$ and $F_{s}^{l}$ are the feature maps of the generated, content, and style images at layer $l$, $C$ and $S$ are the sets of content and style layers, and $G(\cdot)$ denotes the Gram matrix.
The total loss consists of two parts: the content loss and the style loss. The content loss is the mean squared error between the corresponding feature maps of the generated image and the content image; the style loss, however, first computes the Gram matrices of the generated image and the style image and then takes their mean squared error. Here $\alpha$ and $\beta$ are the weights of the content loss and the style loss respectively, and $\Gamma$ is a penalty coefficient on the style loss. In my experiments I found that $\beta$ and $\Gamma$ should be set relatively large (the code below uses $\alpha = 1$, $\beta = 10^3$, $\Gamma = 10^2$), so that the style loss is "penalized", i.e. amplified, as much as possible.
import torch
import numpy as np
from PIL import Image
from torchvision.models import vgg19
from torchvision import transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.nn.functional import mse_loss
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Preprocessing: resize, convert to tensor, normalize with ImageNet statistics
def preprocess(img_shape):
    transform = transforms.Compose([
        transforms.Resize(img_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    return transform
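As a quick sanity check (a sketch run alongside the code above; `./content.jpg` is the content image path used later in `train`), the returned transform turns an RGB PIL image into a normalized tensor, and `unsqueeze(0)` adds the batch dimension VGGNet19 expects:

# Usage sketch for preprocess(); assumes an RGB image at './content.jpg'
img = Image.open('./content.jpg')
transform = preprocess((500, 700))
x = transform(img).unsqueeze(0)   # add batch dimension for VGG
print(x.shape, x.dtype)           # torch.Size([1, 3, 500, 700]) torch.float32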
class VGGNet19(nn.Module):
    def __init__(self):
        super(VGGNet19, self).__init__()
        self.vggnet19 = vgg19(pretrained=False)
        # Load the pretrained weights from a local checkpoint
        self.vggnet19.load_state_dict(torch.load('./vgg19-dcbb9e9d.pth'))
        # VGG is only a fixed feature extractor here; freeze it so backward()
        # does not build gradients for its parameters
        for param in self.vggnet19.parameters():
            param.requires_grad_(False)
        self.content_layers = [25]               # conv4_4, near the top of the "pyramid"
        self.style_layers = [0, 5, 10, 19, 28]   # conv1_1 to conv5_1, local to global

    def forward(self, x):
        content_features = []
        style_features = []
        for name, module in self.vggnet19.features._modules.items():
            x = module(x)
            if int(name) in self.content_layers:
                content_features.append(x)
            if int(name) in self.style_layers:
                style_features.append(x)
        return content_features, style_features
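To see why these particular indices were chosen, one can enumerate the layers of `vgg19().features` and note where each convolution sits (a small sketch using only `torchvision`):

# Enumerate the conv layers of VGG19's feature extractor to map indices to names
from torchvision.models import vgg19
import torch.nn as nn

features = vgg19(pretrained=False).features
for idx, layer in enumerate(features):
    if isinstance(layer, nn.Conv2d):
        print(idx, layer)
# Indices 0, 5, 10, 19, 28 are conv1_1, conv2_1, conv3_1, conv4_1, conv5_1
# (the style layers); index 25 is conv4_4, the content layer near the top.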
class GenerateImage(nn.Module):
    """The generated image itself is the "model": a learnable pixel tensor."""
    def __init__(self, img_shape):
        super(GenerateImage, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(*img_shape))

    def forward(self):
        return self.weight
# Initialize the generated image with the content image
def generate_inits(content, device, lr):
    g_img = GenerateImage(content.shape).to(device)
    g_img.weight.data = content.data
    optimizer = torch.optim.Adam(g_img.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    return g_img(), optimizer
# Gram matrix: channel-by-channel inner products of the flattened feature map,
# normalized by the number of elements
def gramMatrix(x):
    _, c, h, w = x.shape
    x = x.view(c, h*w)
    return torch.matmul(x, x.t()) / (c*h*w)
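A small sanity check with random data (a sketch, not part of the training script) confirms the output shape and symmetry of the Gram matrix: for a (1, c, h, w) feature map the result is c×c, with entry (i, j) the normalized inner product of channels i and j.

# Sanity-check sketch for gramMatrix() with a random feature map
feat = torch.rand(1, 8, 4, 4)        # (batch, channels, height, width)
g = gramMatrix(feat)
print(g.shape)                       # torch.Size([8, 8]): one entry per channel pair
print(torch.allclose(g, g.t()))      # True: Gram matrices are symmetric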
# Total loss: weighted content loss plus weighted (and penalized) style loss
def compute_loss(content_g, content_y, style_g, style_y, content_weight, style_weight, gamma):
    contentlosses = [mse_loss(g, y)*content_weight for g, y in zip(content_g, content_y)]
    stylelosses = [mse_loss(gramMatrix(g), gramMatrix(y))*style_weight for g, y in zip(style_g, style_y)]
    total_loss = sum(contentlosses) + gamma * sum(stylelosses)
    return contentlosses, stylelosses, total_loss
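A quick sketch with random feature lists (the shapes here are arbitrary) shows how `compute_loss` is called and that the total matches the formula above:

# Sketch: call compute_loss on random feature lists (hypothetical shapes)
c_g = [torch.rand(1, 512, 32, 32)]                     # generated content features
c_y = [torch.rand(1, 512, 32, 32)]                     # target content features
s_g = [torch.rand(1, 64, 128, 128) for _ in range(2)]  # generated style features
s_y = [torch.rand(1, 64, 128, 128) for _ in range(2)]  # target style features
cl, sl, total = compute_loss(c_g, c_y, s_g, s_y, content_weight=1, style_weight=1e3, gamma=1e2)
print(total == sum(cl) + 1e2 * sum(sl))                # tensor(True): total = content + gamma * style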
# Postprocessing for visualization: undo the normalization and convert to PIL
def postprocess(img_tensor):
    rgb_mean = np.array([0.485, 0.456, 0.406])
    rgb_std = np.array([0.229, 0.224, 0.225])
    inv_normalize = transforms.Normalize(
        mean=-rgb_mean/rgb_std,
        std=1/rgb_std)
    to_PIL_image = transforms.ToPILImage()
    return to_PIL_image(inv_normalize(img_tensor[0].detach().cpu()).clamp(0, 1))
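As a round-trip sketch (again assuming `./content.jpg`), `preprocess` followed by `postprocess` should approximately reproduce the resized input, up to the clamp:

# Round-trip sketch: normalize then de-normalize an image tensor
x = preprocess((500, 700))(Image.open('./content.jpg')).unsqueeze(0)
img = postprocess(x)   # PIL image, approximately the resized original
print(img.size)        # (700, 500): PIL reports (width, height)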
def train(lr, epoch_num, c_path, s_path, img_shape):
    ipt = Image.open(c_path)
    syl = Image.open(s_path)
    transform = preprocess(img_shape)
    content, style = transform(ipt).unsqueeze(0), transform(syl).unsqueeze(0)
    net = VGGNet19()
    net.to(device).eval()
    content = content.to(device)
    style = style.to(device)
    # Target features are fixed, so compute them once without building a graph
    with torch.no_grad():
        icontent, _ = net(content)   # content targets
        _, sstyle = net(style)       # style targets
    input, optimizer = generate_inits(content, device, lr)
    for epoch in range(epoch_num+1):
        gcontent, gstyle = net(input)
        contentlosses, stylelosses, total_loss = compute_loss(gcontent, icontent, gstyle, sstyle, 1, 1e3, 1e2)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        print("[epoch: %3d/%3d] content loss: %.3f style loss: %.3f total loss: %.3f"
              % (epoch, epoch_num, sum(contentlosses).item(), sum(stylelosses).item(), total_loss.item()))
        if epoch % 100 == 0 and epoch != 0:
            # plt.imshow(postprocess(input))
            # plt.axis('off')
            # plt.show()
            # Save the synthesized image (not the frozen VGG weights)
            postprocess(input).save("itr_%d_total_loss_%.3f.jpg" % (epoch, total_loss.item()))
if __name__ == "__main__":
    train(0.01, 10000, './content.jpg', './s.jpg', (500, 700))
The content image, the style image, and the generated image (visualized at the 10000th iteration) are shown above; the code implements the method of Gatys et al.