天天看點

pytorch中使用vutils對多張圖像進行拼接 (import torchvision.utils as vutils)

1.png

pytorch中使用vutils對多張圖像進行拼接 (import torchvision.utils as vutils)

2.png

在pytorch中使用torchvision的vutils函數實作對多張圖檔的拼接。具體操作就是将上面的兩張圖檔,1.png和2.png的多張圖檔進行拼接形成一張圖檔,拼接後的效果如下圖。

給出具體代碼:

import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch
import torchvision.utils as vutils

im1=Image.open("1.png").convert("RGB")
im1 = im1.resize((1000, 1000)).rotate(-90)

im2=Image.open("2.png").convert("RGB")
im2 = im2.resize((1000, 1000)).rotate(-90)


# 1000, 1000, 3 => 3, 1000, 1000
images = [np.moveaxis(np.array(im1), 2, 0), np.moveaxis(np.array(im2), 2, 0)]*8


images_tensor = vutils.make_grid(torch.tensor(images)/255.0, nrow=4, padding=0, normalize=True)
print(images_tensor.shape)
# 3, 1000, 1000 => 1000, 1000, 3 
plt.imshow(images_tensor.numpy().transpose((1,2,0)))
plt.show()

vutils.save_image(images_tensor, "3.png")
vutils.save_image(images_tensor, "3_back.png", nrow=2, padding=0, normalize=True)
vutils.save_image(torch.tensor(images)/255.0, "4.png", nrow=8, padding=0, normalize=True)      

=============================================

需要注意的地方:

  • 1.  使用PIL讀入的圖檔要轉為RGB模式,然後要将圖檔對象轉為numpy數組形式,在上面例子中轉為數組後的單張圖檔次元為(1000,1000,3)。
  • 2.  使用vutils.make_grid函數對圖檔進行拼接時,每張圖檔的資料類型都為torch.tensor,并且單張圖檔的格式應為(channel數,長,寬),上面例子中則是(3,1000,1000)。這樣将16張圖檔拼接為每行4張圖檔的大圖後,大圖的次元為(3,4000,4000)。
  • vutils.make_grid函數和vutils.save_image函數接受的pytorch.tensor的類型均為float,如果不能保證資料大小在0和1之間則需要設定正則項normalize=True 。