天天看點

torch_vision(二):模型和預訓練weight子產品 torchvision.modelstorchvision.models簡單介紹

torchvision.models簡單介紹

介紹

torchvision.models子產品提供了很多模型架構,以及對應的預先訓練好的權重。

最新的版本的特性是相比于舊版本

  1. 一個模型架構可以加載多種不同的權重。
  2. 可以擷取到預處理方法,這些轉換中的任何細微差異(例如插值、調整大小/裁剪大小等)都可能導緻準确性大幅降低或模型無法使用。
  3. 提供中繼資料,包括類别标簽,準确度等名額。

以一個分類模型為例:

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
# ResNet50_Weights.IMAGENET1K_V1  ResNet50_Weights.IMAGENET1K_V2是其他可以選擇的版本,DEFAULT一般是最優的版本
weights = ResNet50_Weights.DEFAULT 
model = resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")
           

目标檢測

from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = [preprocess(img)]

# Step 4: Use the model and visualize the prediction
prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="red",
                          width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()
           

語義分割

from torchvision.io.image import read_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image

img = read_image("gallery/assets/dog1.jpg")

# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
model = fcn_resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and visualize the prediction
prediction = model(batch)["out"]
normalized_masks = prediction.softmax(dim=1)
class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()
           

可以選擇的模型和weight

torchvision.models包含很多模型和預先訓練好的weight, 能夠處理多種任務,圖像分類,語義分割,目标檢測,關鍵點檢測,視訊分類,光流估計等。

The torchvision.models subpackage contains definitions of models for addressing different tasks, 
including:image classification, pixelwise semantic segmentation, object detection, instance segmentation, 
person keypoint detection, video classification, and optical flow.
           

具體各個任務有哪些可以在torchvision.models可以擷取到的模型,請檢視

MODELS AND PRE-TRAINED WEIGHTS

其實覆寫的模型不算多,超分,生成模型,圖像修複,圖像增強等多種任務并沒有相關模型在torchvision.models中。

繼續閱讀