Hung-yi Lee Machine Learning Homework 12 - Domain Adversarial Training - Code

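1. Task description

  • Domain Adaptation: improve accuracy on dataset B when training uses only the labels of dataset A and none of the labels of dataset B. (Dataset A and its task are close to dataset B and its task.)
  • Given real images with labels and a large number of unlabeled hand-drawn images, design a method that lets the model predict the labels of the hand-drawn images.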

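2. Dataset

  • Training: 5000 real images + labels, 32 x 32 RGB
  • Testing: 100000 hand-drawn images, 28 x 28 grayscale
  • Label: 10 classes to predict in total: horse, bed, clock, apple, cat, plane, television, dog, dolphin, spider.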

Download link:

Link: https://pan.baidu.com/s/1gsTpROafCoqYL_PTIlEuhQ

Extraction code: kc6r


A look at the training set

import matplotlib.pyplot as plt


def no_axis_show(img, title='', cmap=None):
    # imshow with nearest-neighbor interpolation.
    fig = plt.imshow(img, interpolation='nearest', cmap=cmap)
    # Hide the axes.
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.title(title)

    


titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']


plt.figure(figsize=(18, 18))
for i in range(10):
    plt.subplot(1, 10, i + 1)
    fig = no_axis_show(plt.imread(f'/kaggle/input/real-or-drawing/real_or_drawing/train_data/{i}/{i*500}.bmp'), title=titles[i])

           

A look at the test set

plt.figure(figsize=(18, 18))
for i in range(10):
    plt.subplot(1, 10, i + 1)
    fig = no_axis_show(plt.imread(f'/kaggle/input/real-or-drawing/real_or_drawing/test_data/0/0000{i}.bmp'), title='none')
           

3. Preprocessing the source data

Because people usually draw only outlines when doodling, we can apply edge detection to the source data so that it looks more like the target data.

Canny Edge Detection

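The algorithm itself is not covered here, only how to use it. If you are interested, see the Wikipedia article on the Canny edge detector.

cv2.Canny is very easy to use. Besides the image, it needs only two parameters: low_threshold and high_threshold.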

cv2.Canny(image, low_threshold, high_threshold)

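Simply put, a pixel whose gradient magnitude exceeds high_threshold is accepted as an edge right away. A pixel that only exceeds low_threshold is kept as an edge only if it is connected to a pixel already accepted as an edge; otherwise it is discarded.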
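
To make the two-threshold rule concrete, here is a minimal sketch on a single row of gradient magnitudes. It is a simplified 1-D illustration written for this post, not the code inside cv2.Canny; the function name `hysteresis_1d` and the sample values are made up for the example.

```python
def hysteresis_1d(magnitudes, low, high):
    """Simplified 1-D double-threshold step (illustration only, not cv2's code)."""
    strong = [m > high for m in magnitudes]    # accepted as edges immediately
    candidate = [m > low for m in magnitudes]  # strong or weak pixels
    edge = list(strong)
    # Left-to-right pass: a candidate next to an accepted edge becomes an edge.
    for i in range(1, len(magnitudes)):
        if candidate[i] and edge[i - 1]:
            edge[i] = True
    # Right-to-left pass: the same rule in the other direction.
    for i in range(len(magnitudes) - 2, -1, -1):
        if candidate[i] and edge[i + 1]:
            edge[i] = True
    return edge


magnitudes = [10, 60, 120, 70, 20, 80, 30]
edges = hysteresis_1d(magnitudes, low=50, high=100)
print(edges)
```

In this row, 60 and 70 are kept because they touch the strong pixel 120, while 80 is dropped: it is above low_threshold but isolated from any strong pixel.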

Below we try it directly on the source data.

import cv2

plt.figure(figsize=(18, 18))

original_img = plt.imread(f'/kaggle/input/real-or-drawing/real_or_drawing/train_data/0/0.bmp')
plt.subplot(1, 5, 1)
no_axis_show(original_img, title='original')

gray_img = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)
plt.subplot(1, 5, 2)
no_axis_show(gray_img, title='gray scale', cmap='gray')



canny_50100 = cv2.Canny(gray_img, 50, 100)
plt.subplot(1, 5, 3)
no_axis_show(canny_50100, title='Canny(50, 100)', cmap='gray')

canny_150200 = cv2.Canny(gray_img, 150, 200)
plt.subplot(1, 5, 4)
no_axis_show(canny_150200, title='Canny(150, 200)', cmap='gray')

canny_250300 = cv2.Canny(gray_img, 250, 300)
plt.subplot(1, 5, 5)
no_axis_show(canny_250300, title='Canny(250, 300)', cmap='gray')
           

4. Building the dataloaders

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

source_transform = transforms.Compose([
    # Convert to grayscale: Canny does not take RGB input.
    transforms.Grayscale(),
    # cv2 does not accept a PIL Image, so convert to np.array before calling cv2.Canny.
    transforms.Lambda(lambda x: cv2.Canny(np.array(x), 170, 300)),
    # Convert the np.array back to a PIL Image.
    transforms.ToPILImage(),
    # Random horizontal flip (augmentation).
    transforms.RandomHorizontalFlip(),
    # Rotate by up to 15 degrees (augmentation); fill the empty corners with 0.
    transforms.RandomRotation(15, fill=(0,)),
    # Finally convert to a Tensor for the model.
    transforms.ToTensor(),
])
target_transform = transforms.Compose([
    # Convert to grayscale.
    transforms.Grayscale(),
    # Resize: the source data is 32x32, so enlarge the 28x28 target data to 32x32.
    transforms.Resize((32, 32)),
    # Random horizontal flip (augmentation).
    transforms.RandomHorizontalFlip(),
    # Rotate by up to 15 degrees (augmentation); fill the empty corners with 0.
    transforms.RandomRotation(15, fill=(0,)),
    # Finally convert to a Tensor for the model.
    transforms.ToTensor(),
])

source_dataset = ImageFolder('/kaggle/input/real-or-drawing/real_or_drawing/train_data', transform=source_transform)
target_dataset = ImageFolder('/kaggle/input/real-or-drawing/real_or_drawing/test_data', transform=target_transform)

source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)
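
The two training loaders have very different lengths, which matters because the training loop in section 6 walks them in parallel with zip. The sketch below only does the arithmetic, using the dataset sizes from section 2 and batch_size=32 from the loaders above; `num_batches` is a helper invented for this example and assumes DataLoader's default drop_last=False.

```python
import math


def num_batches(num_samples, batch_size):
    # With drop_last=False (the DataLoader default) the final partial batch is kept.
    return math.ceil(num_samples / batch_size)


source_batches = num_batches(5000, 32)     # 5000 real images
target_batches = num_batches(100000, 32)   # 100000 hand-drawn images

# zip() stops as soon as the shorter iterable is exhausted.
steps_per_epoch = min(source_batches, target_batches)

# Size of the final source batch in each epoch.
last_source_batch = 5000 - 32 * (source_batches - 1)

print(source_batches, target_batches, steps_per_epoch, last_source_batch)
```

So one epoch covers all of the source data but only about 5% of the target data; because target_dataloader is shuffled, each epoch draws a different subset.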
           

5. Defining the three models

class FeatureExtractor(nn.Module):
    '''
    Extracts features from images
    input [batch_size ,1,32,32]
    output [batch_size ,512]
    '''

    def __init__(self):
        super(FeatureExtractor, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),  # [batch_size ,64,32,32] (32-3+2*1)/1 + 1
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [batch_size ,64,16,16]

            nn.Conv2d(64, 128, 3, 1, 1),  # [batch_size ,128,16,16]
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [batch_size ,128,8,8]

            nn.Conv2d(128, 256, 3, 1, 1),  # [batch_size ,256,8,8]
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [batch_size ,256,4,4]

            nn.Conv2d(256, 256, 3, 1, 1),  # [batch_size ,256,4,4]
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [batch_size ,256,2,2]

            nn.Conv2d(256, 512, 3, 1, 1),  # [batch_size ,512,2,2]
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2)  # [batch_size ,512,1,1]
        )

    def forward(self, x):
        # flatten(1) keeps the batch dimension even when batch_size is 1 (squeeze() would drop it)
        x = self.conv(x).flatten(1)  # [batch_size ,512]
        return x

class LabelPredictor(nn.Module):
    '''
    Predicts which of the 10 classes the image belongs to
    '''
    def __init__(self):
        super(LabelPredictor, self).__init__()

        self.layer = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),

            nn.Linear(256, 128),
            nn.ReLU(),

            nn.Linear(128, 10),
        )

    def forward(self, h):
        c = self.layer(h)
        return c

class DomainClassifier(nn.Module):
    '''Predicts whether an image is hand-drawn or real'''
    def __init__(self):
        super(DomainClassifier, self).__init__()

        self.layer = nn.Sequential(
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),

            nn.Linear(64, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(),

            nn.Linear(16, 4),
            nn.BatchNorm1d(4),
            nn.ReLU(),

            nn.Linear(4, 1),
        )

    def forward(self, h):
        y = self.layer(h)
        return y
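
The shape comments in FeatureExtractor follow from the output-size formula noted on its first layer, (32-3+2*1)/1 + 1. As a quick check, the sketch below applies that formula to the five conv + max-pool stages; `conv_out` and `pool_out` are helper names made up for this example.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    # Output side length of a convolution: (n - k + 2p) / s + 1
    return (size - kernel + 2 * padding) // stride + 1


def pool_out(size, kernel=2):
    # MaxPool2d(2) uses a stride equal to the kernel size and floors the result.
    return (size - kernel) // kernel + 1


size = 32
sizes = []
for _ in range(5):  # five conv + pool stages
    size = pool_out(conv_out(size))
    sizes.append(size)
print(sizes)  # [16, 8, 4, 2, 1]
```

Each 3x3 convolution with padding 1 keeps the side length and each pooling halves it, so a 32x32 input ends as 512 channels of 1x1. That is why the Label Predictor and Domain Classifier both start with nn.Linear(512, ...).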
           
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

feature_extractor = FeatureExtractor().to(device)
label_predictor = LabelPredictor().to(device)
domain_classifier = DomainClassifier().to(device)

class_criterion = nn.CrossEntropyLoss()
domain_criterion = nn.BCEWithLogitsLoss()  # this loss takes raw logits, so the classifier output needs no sigmoid

optimizer_F = optim.Adam(feature_extractor.parameters())
optimizer_C = optim.Adam(label_predictor.parameters())
optimizer_D = optim.Adam(domain_classifier.parameters())
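
BCEWithLogitsLoss combines the sigmoid and the binary cross-entropy in one step, which is why the Domain Classifier ends in a plain Linear layer. The sketch below shows, for a single scalar logit, one common numerically stable way to write that combined loss and checks it against the sigmoid-then-BCE form. It is pure Python for illustration and does not claim to mirror PyTorch's internals; both function names are made up here.

```python
import math


def bce_with_logits(x, y):
    """Stable form of BCE on a raw logit x: max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def bce_after_sigmoid(x, y):
    """Textbook form: apply the sigmoid first, then binary cross-entropy."""
    p = 1.0 / (1.0 + math.exp(-x))
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


for logit, label in [(2.0, 1.0), (-1.5, 1.0), (0.3, 0.0)]:
    assert abs(bce_with_logits(logit, label) - bce_after_sigmoid(logit, label)) < 1e-9

# The stable form still works where the textbook form would hit log(0).
print(bce_with_logits(1000.0, 0.0))
```

For moderate logits the two forms agree; for a very large logit the sigmoid rounds to exactly 1 and the textbook form breaks, while the stable form simply returns the logit.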
           

6. Training

def train_epoch(source_dataloader, target_dataloader, lamb):
    '''
      Args:
        source_dataloader: dataloader for the source data
        target_dataloader: dataloader for the target data
        lamb: coefficient that scales the adversarial loss.
    '''

    # D loss: loss of the Domain Classifier
    # F loss: loss of the Feature Extractor & Label Predictor
    # total_hit: number of correct predictions so far; total_num: number of samples seen so far
    running_D_loss, running_F_loss = 0.0, 0.0
    total_hit, total_num = 0.0, 0.0

    for i, ((source_data, source_label), (target_data, _)) in enumerate(zip(source_dataloader, target_dataloader)):
        # zip stops at the shorter dataloader, which here is the source one, so each epoch
        # pairs every source batch with one target batch. The target dataloader is shuffled,
        # so the remaining target batches are drawn in later epochs.
        source_data = source_data.to(device)
        source_label = source_label.to(device)
        target_data = target_data.to(device)

        # Mix the source data and target data in one batch; otherwise batch_norm statistics
        # could be off (the mean/var of the two domains differ).
        mixed_data = torch.cat([source_data, target_data], dim=0)
        domain_label = torch.zeros([source_data.shape[0] + target_data.shape[0], 1]).to(device)
        # Set the domain label of the source data to 1.
        domain_label[:source_data.shape[0]] = 1

        # Step 1: train the Domain Classifier
        feature = feature_extractor(mixed_data)
        # Step 1 does not train the Feature Extractor, so detach the feature to keep the loss
        # from backpropagating into it.
        domain_logits = domain_classifier(feature.detach())
        loss = domain_criterion(domain_logits, domain_label)
        running_D_loss += loss.item()
        loss.backward()
        optimizer_D.step()

        # Step 2: train the Feature Extractor and the Label Predictor
        class_logits = label_predictor(feature[:source_data.shape[0]])
        domain_logits = domain_classifier(feature)
        # The loss is class CE - lamb * domain BCE. The subtraction plays the same role as
        # the generator's loss against the discriminator in a GAN.
        loss = class_criterion(class_logits, source_label) - lamb * domain_criterion(domain_logits, domain_label)
        running_F_loss += loss.item()
        loss.backward()
        # optimizer_D.step() is not called here; the gradients this backward pass leaves in
        # the Domain Classifier are cleared by zero_grad below.
        optimizer_F.step()
        optimizer_C.step()

        optimizer_D.zero_grad()
        optimizer_F.zero_grad()
        optimizer_C.zero_grad()

        total_hit += torch.sum(torch.argmax(class_logits, dim=1) == source_label).item()
        total_num += source_data.shape[0]
        print(i, end='\r')

    return running_D_loss / (i + 1), running_F_loss / (i + 1), total_hit / total_num


# Train for 200 epochs
for epoch in range(200):
    train_D_loss, train_F_loss, train_acc = train_epoch(source_dataloader, target_dataloader, lamb=0.1)

    torch.save(feature_extractor.state_dict(), 'extractor_model.bin')
    torch.save(label_predictor.state_dict(), 'predictor_model.bin')

    print('epoch {:>3d}: train D loss: {:6.4f}, train F loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_D_loss,
                                                                                           train_F_loss, train_acc))
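
The training code above gets the adversarial effect by subtracting lamb * domain BCE in a second step. A commonly used alternative, which this post does not use, is a gradient reversal layer: it passes features through unchanged in the forward pass and multiplies the gradient by -lamb in the backward pass. The sketch below is a minimal version of that idea; the class name `GradReverse` is hypothetical, and the toy tensor stands in for the extractor's output.

```python
import torch
from torch.autograd import Function


class GradReverse(Function):
    """Identity in the forward pass; scales the gradient by -lamb in the backward pass."""

    @staticmethod
    def forward(ctx, x, lamb):
        ctx.lamb = lamb
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; lamb is a plain float, so it gets None.
        return -ctx.lamb * grad_output, None


# Toy check: the values pass through, the gradient comes back reversed and scaled.
feature = torch.ones(2, 3, requires_grad=True)
reversed_feature = GradReverse.apply(feature, 0.1)
reversed_feature.sum().backward()
print(feature.grad)
```

With such a layer placed between the extractor and the domain classifier, a single backward pass on class CE + domain BCE would update the domain classifier normally while pushing the extractor in the opposite direction. That is a different training schedule from the two-step loop above, so results need not match.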
           

7. Predicting on the test set

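result = []
label_predictor.eval()
feature_extractor.eval()
# Note: test_dataloader wraps target_dataset, so the random flip and rotation in
# target_transform are also applied to these test images.
with torch.no_grad():  # no gradients are needed for inference
    for i, (test_data, _) in enumerate(test_dataloader):
        test_data = test_data.to(device)

        class_logits = label_predictor(feature_extractor(test_data))

        x = torch.argmax(class_logits, dim=1).cpu().detach().numpy()
        result.append(x)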

import pandas as pd
result = np.concatenate(result)

# Generate your submission
df = pd.DataFrame({'id': np.arange(0,len(result)), 'label': result})
df.to_csv('DaNN_submission.csv',index=False)
           

Show the predictions for the first 100 images

labels = iter(df['label'][:100])
def f_names():
    for i in range(100):
        yield '/kaggle/input/real-or-drawing/real_or_drawing/test_data/0/{:05}.bmp'.format(i)

names = iter(f_names())


for j in range(10):
    plt.figure(figsize=(18, 18))
    for i in range(10):
        plt.subplot(1, 10, i + 1)
        name = next(names)
        label = next(labels)
        fig = no_axis_show(plt.imread(name),title=titles[label])
           