天天看点

TorchVision 预训练模型进行推断

torchvision.models 里包含了许多模型,用于解决不同的视觉任务:图像分类、语义分割、物体检测、实例分割、人体关键点检测和视频分类。

本文将介绍 torchvision 中模型的入门使用,一起来创建 Faster R-CNN 预训练模型,预测图像中有什么物体吧。

<code>print(model)</code> 可查看其结构:

此预训练模型是于 COCO train2017 上训练的,可预测的分类有:

获取支持的 <code>device</code>:

模型移到 <code>device</code>:

准备模型入参 <code>images</code>:

例图 <code>data/bicycle.jpg</code>:

TorchVision 预训练模型进行推断

模型切为 <code>eval</code> 模式:

模型在推断时,只需要给到图像数据,不用标注数据。推断后,会返回每个图像的预测结果 <code>List[Dict[Tensor]]</code>。<code>Dict</code> 包含字段有:

boxes (<code>FloatTensor[N, 4]</code>): 预测框 <code>[x1, y1, x2, y2]</code>, <code>x</code> 范围 <code>[0,W]</code>, <code>y</code> 范围 <code>[0,H]</code>

labels (<code>Int64Tensor[N]</code>): 预测类别

scores (<code>Tensor[N]</code>): 预测评分

预测结果如下:

获取 <code>score &gt;= 0.9</code> 的预测结果:

引入 <code>utils.plots.plot_image</code> 绘制结果:

<code>utils.plots.plot_image</code> 函数实现可见后文源码,注意其要求 <code>torchvision &gt;= 0.9.0/nightly</code>。

test_pretrained_models.py

<code>utils.colors.golden</code>:

<code>utils.plots.plot_image</code>:

torch.hub

torchvision.models

GoCoding 个人实践的经验分享,可关注公众号!

继续阅读