torchvision.models 里包含了许多模型,用于解决不同的视觉任务:图像分类、语义分割、物体检测、实例分割、人体关键点检测和视频分类。
本文将介绍 torchvision 中模型的入门使用,一起来创建 Faster R-CNN 预训练模型,预测图像中有什么物体吧。
<code>print(model)</code> 可查看其结构:
此预训练模型是于 COCO train2017 上训练的,可预测的分类有:
获取支持的 <code>device</code>:
模型移到 <code>device</code>:
准备模型入参 <code>images</code>:
例图 <code>data/bicycle.jpg</code>:

模型切为 <code>eval</code> 模式:
模型在推断时,只需要给到图像数据,不用标注数据。推断后,会返回每个图像的预测结果 <code>List[Dict[Tensor]]</code>。<code>Dict</code> 包含字段有:
boxes (<code>FloatTensor[N, 4]</code>): 预测框 <code>[x1, y1, x2, y2]</code>, <code>x</code> 范围 <code>[0,W]</code>, <code>y</code> 范围 <code>[0,H]</code>
labels (<code>Int64Tensor[N]</code>): 预测类别
scores (<code>Tensor[N]</code>): 预测评分
预测结果如下:
获取 <code>score >= 0.9</code> 的预测结果:
引入 <code>utils.plots.plot_image</code> 绘制结果:
<code>utils.plots.plot_image</code> 函数实现可见后文源码,注意其要求 <code>torchvision >= 0.9.0/nightly</code>。
test_pretrained_models.py
<code>utils.colors.golden</code>:
<code>utils.plots.plot_image</code>:
torch.hub
torchvision.models
GoCoding 个人实践的经验分享,可关注公众号!