天天看點

訓練一個資料不夠多的資料集是什麼體驗?

摘要:這裡介紹其中一種帶标簽擴充資料集的方法。

前言

前一段時間接觸了幾位使用者提的問題,發現很多人在使用訓練的時候,給的資料集寥寥無幾,有一些甚至一類隻有5張圖檔。modelarts平台雖然給出了每類5張圖檔就能訓練的限制,但是這種限制對一個工業級的應用場景往往是遠遠不夠的。是以聯系了使用者希望多增加一些圖檔,增加幾千張圖檔訓練。但是使用者後面回報,标注的工作量實在是太大了。我思忖了一下,分析了一下他應用的場景,做了一些政策變化。這裡介紹其中一種帶标簽擴充資料集的方法。

資料集情況

資料集由于屬于使用者資料,不能随便展示,這裡用一個可以展示的開源資料集來替代。首先,這是一個分類的問題,需要檢測出工業零件表面的瑕疵,判斷是否為殘次品,如下是樣例圖檔:

訓練一個資料不夠多的資料集是什麼體驗?

這是兩塊太陽能電闆的表面,左側是正常的,右側是有殘缺和殘次現象的,我們需要用一個模型來區分這兩類的圖檔,幫助定位哪些太陽能電闆存在問題。左側的正常樣本754張,右側的殘次樣本358張,驗證集同樣,正常樣本754張,殘次樣本357張。總樣本在2000張左右,對于一般工業要求的95%以上準确率模型而言屬于一個非常小的樣本。先直接拿這個資料集用Pytorch加載imagenet的resnet50模型訓練了一把,整體精度ACC在86.06%左右,召回率正常類為97.3%,但非正常類為62.9%,還不能達到使用者預期。

當要求使用者再多收集,至少擴充到萬級的資料集的時候,使用者提出,收集資料要經過處理,還要标注,很麻煩,問有沒有其他的辦法可以節省一些工作量。這可一下難倒了我,資料可是深度學習訓練的靈魂,這可咋整啊。

訓練一個資料不夠多的資料集是什麼體驗?

仔細思考了一陣子,想到modelarts上有智能标注然後人工校驗的功能,就讓使用者先試着體驗一下這個功能。我這邊拿他給我的資料集想想辦法。查了些資料,小樣本學習few-shot fewshot learning (FSFSL)的常見方法,基本都是從兩個方向入手。一是資料本身,二是從模型訓練本身,也就是對圖像提取的特征做文章。這裡想着從資料本身入手。

首先觀察資料集,都是300*300的灰階圖像,而且都已太陽能電闆表面的正面俯視為整張圖檔。這屬于預先處理的很好的圖檔。那麼針對這種圖檔,翻轉鏡像對圖檔整體結構影響不大,是以我們首先可以做的就是flip操作,增加資料的多樣性。flip效果如下:

訓練一個資料不夠多的資料集是什麼體驗?

這樣資料集就從1100張擴增到了2200張,還是不是很多,但是直接觀察資料集已經沒什麼太好的擴充辦法了。這時想到用Modelarts模型評估的功能來評估一下模型對資料的泛化能力。這裡調用了提供的SDK:deep_moxing.model_analysis下面的analyse接口。

def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')
    pred_list = []
    target_list = []
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)
            # 擷取logits輸出結果pred和實際目标的結果target
            pred_list += output.cpu().numpy()[:, :2].tolist()
            target_list += target.cpu().numpy().tolist()
            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5), i=i)
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
    # 擷取圖檔的存儲路徑name
    name_list = val_loader.dataset.samples
    for idx in range(len(name_list)):
        name_list[idx] = name_list[idx][0]
    analyse(task_type='image_classification', save_path='/home/image_labeled/',
            pred_list=pred_list, label_list=target_list, name_list=name_list)
    return top1.avg      

上段代碼大部分都是Pytorch訓練ImageNet中的驗證部分代碼,需要擷取三個list,模型pred直接結果logits、圖檔實際類别target和圖檔存儲路徑name。然後按如上的調用方法調用analyse接口,會在save_path的目錄下生成一個json檔案,放到Modelarts訓練輸出目錄裡,就能在評估結果裡看到對模型的分析結果。我這裡是線下生成的json檔案再上傳到線上看可視化結果。關于敏感度分析結果如下:

訓練一個資料不夠多的資料集是什麼體驗?

這幅圖的意思是,不同的特征值範圍圖檔分别測試的精度是多少。比如亮度敏感度分析的第一項0%-20%,可以了解為,在圖檔亮度較低的場景下對與0類和其他亮度條件的圖檔相比,精度要低很多。整體來看,主要是為了檢測1類,1類在圖檔的亮度和清晰度兩項上顯得都很敏感,也就是模型不能很好地處理圖檔的這兩項特征變化的圖檔。那這不就是我要擴增資料集的方向嗎?

訓練一個資料不夠多的資料集是什麼體驗?

好的,那麼我就試着直接對全量的資料集做了擴增,得到一個正常類2210張,瑕疵類1174張圖檔的資料集,用同樣的政策扔進pytorch中訓練,得到的結果:

訓練一個資料不夠多的資料集是什麼體驗?

怎麼回事,和設想的不太一樣啊。。。

訓練一個資料不夠多的資料集是什麼體驗?

重新分析一下資料集,我突然想到,這種工業類的資料集往往都存在一個樣本不均勻的問題,這裡雖然接近2:1,但是檢測的要求針對有瑕疵的類别的比較高,應該讓模型傾向于有瑕疵類去學習,而且看到1類的也就是有瑕疵類的結果比較敏感,是以其實還是存在樣本不均衡的情況。由此後面的這兩種增強方法隻針對了1類也就是有問題的破損類做,最終得到3000張左右,1508張正常類圖檔,1432張有瑕疵類圖檔,這樣樣本就相對平衡了。用同樣的政策扔進resnet50中訓練。最終得到的精度資訊:

訓練一個資料不夠多的資料集是什麼體驗?

可以看到,同樣在驗證集,正常樣本754張,殘次樣本357張的樣本上,Acc1的精度整體提升了接近3%,重要名額殘次類的recall提升了8.4%!嗯,很不錯。是以直接擴充資料集的方法很有效,而且結合模型評估能讓我參考哪些擴增的方法是有意義的。當然還有很重要的一點,要排除原始資料集存在的問題,比如這裡存在的樣本不均衡問題,具體情況具體分析,這個擴增的方法就會變得簡單實用。

之後基于這個實驗的結果和資料集。給幫助使用者改了一些訓練政策,換了個更厲害的網絡,就達到了使用者的要求,當然這都是定制化分析的結果,這裡不詳細展開說明了,或者會在以後的部落格中更新。

引用資料集來自:

Buerhop-Lutz, C.; Deitsch, S.; Maier, A.; Gallwitz, F.; Berger, S.; Doll, B.; Hauch, J.; Camus, C. & Brabec, C. J. A Benchmark for Visual Identification of Defective Solar Cells in Electroluminescence Imagery. European PV Solar Energy Conference and Exhibition (EU PVSEC), 2018. DOI: 10.4229/35thEUPVSEC20182018-5CV.3.15

Deitsch, S.; Buerhop-Lutz, C.; Maier, A. K.; Gallwitz, F. & Riess, C. Segmentation of Photovoltaic Module Cells in Electroluminescence Images. CoRR, 2018, abs/1806.06530

Deitsch, S.; Christlein, V.; Berger, S.; Buerhop-Lutz, C.; Maier, A.; Gallwitz, F. & Riess, C. Automatic classification of defective photovoltaic module cells in electroluminescence images. Solar Energy, Elsevier BV, 2019, 185, 455-468. DOI: 10.1016/j.solener.2019.02.067

點選關注,第一時間了解華為雲新鮮技術~