
Trying triplet loss on MNIST (mxnet)

triplet loss

Using the triplet loss function for similarity computation on MNIST.

Triplet loss has three core parts:

  1. anchor/positive/negative

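    These are the three input images, all of the same size. Training aims to minimize the distance between anchor and positive while maximizing the distance between anchor and negative. Taking face recognition as an example, anchor and positive usually come from the same person, while negative belongs to a different person.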

  2. shared models

    A common convolutional model shared by all three inputs: the input is a single image and the output is a 1-D feature vector.

  3. loss

    $$L_i = [ (f(x_i^a) - f(x_i^p))^2 - (f(x_i^a) - f(x_i^n))^2 + \alpha ]$$

    $$L = \sum_i^N [ \max(L_i, 0) ]$$

    where $\alpha$ is the margin hyperparameter.
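
The formula can be checked by hand on a couple of toy triplets. The following is a minimal sketch in plain Python (no mxnet needed); the function names and the 2-D feature values are made up for illustration, and the squared term is read as the squared Euclidean distance between feature vectors, which matches the `sum(axis=1)` in the training code.

```python
def squared_distance(u, v):
    """Squared Euclidean distance between two equal-length feature vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def triplet_loss(anchors, positives, negatives, alpha):
    """Return (L, [max(L_i, 0), ...]) for a list of triplets."""
    hinged = []
    for a, p, n in zip(anchors, positives, negatives):
        l_i = squared_distance(a, p) - squared_distance(a, n) + alpha
        hinged.append(max(l_i, 0.0))  # triplets already separated by the margin contribute 0
    return sum(hinged), hinged


# Toy 2-D features (illustrative values, not taken from the experiment).
anchors = [[0.0, 0.0], [1.0, 1.0]]
positives = [[0.0, 1.0], [1.0, 2.0]]
negatives = [[3.0, 0.0], [1.0, 1.5]]

total, per_triplet = triplet_loss(anchors, positives, negatives, alpha=1.0)
print(per_triplet, total)  # [0.0, 1.75] 1.75
```

The first triplet has its negative far from the anchor, so its hinge term is zero and it produces no gradient; only the second triplet, whose negative is closer than its positive, contributes to the loss.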

Experimental results

Here is an example of trying triplet loss on the MNIST dataset. To reduce computation, only part of the data is actually used, but the effect is still visible.

Initial class distribution (ten classes in total, one color per class)


Result after 10 epochs


After 90 epochs


Samples of the same class gradually cluster together, while the distance between different classes gradually increases.
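
One way to put a number on this visual impression is to compare the mean distance between same-class pairs with the mean distance between different-class pairs. The sketch below is an assumption about how that could be measured, not something the original experiment reports; the function name and the toy features are hypothetical.

```python
from itertools import combinations


def mean_class_distances(features, labels):
    """Return (mean intra-class, mean inter-class) squared Euclidean distance."""
    intra, inter = [], []
    for (f1, l1), (f2, l2) in combinations(zip(features, labels), 2):
        d = sum((a - b) ** 2 for a, b in zip(f1, f2))
        (intra if l1 == l2 else inter).append(d)
    return sum(intra) / len(intra), sum(inter) / len(inter)


# Toy 2-D features: two tight clusters far apart (illustrative values).
features = [[0.0, 0.0], [0.0, 1.0], [4.0, 0.0], [4.0, 1.0]]
labels = [0, 0, 1, 1]

intra, inter = mean_class_distances(features, labels)
print(intra, inter)  # 1.0 16.5
```

If training behaves as described, the intra-class value should shrink relative to the inter-class value as epochs go by.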

Key code

  1. Code that organizes MNIST into triplets
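import os

import cv2
import numpy as np
from mxnet import gluon

# sample_size (the size tuple passed to cv2.resize) and smallset_num (per-class image cap,
# disabled when <= 0) are module-level settings defined elsewhere in the script.
class TripletMNIST(gluon.data.Dataset):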
    def __init__(self,fortrain,dataset_root="C:/dataset/mnist/",resize=sample_size):
        super(TripletMNIST,self).__init__()
        self.data_pairs = {}
        self.total = 0
        self.resize = resize
        if fortrain:
            ds_root = os.path.join(dataset_root,'train')
        else:
            ds_root = os.path.join(dataset_root,"test")
        for rdir, pdirs, names in os.walk(ds_root):
            for name in names:
                basename,ext = os.path.splitext(name)
                if ext != ".jpg":
                    continue
                fullpath = os.path.join(rdir,name)
                # the class label is the name of the image's parent directory (portable across path separators)
                label = os.path.basename(rdir)
                label = int(label)
                if smallset_num > 0 and (label in self.data_pairs) and len(self.data_pairs[label]) >= smallset_num:
                    continue 

                self.data_pairs.setdefault(label,[]).append(fullpath)
                self.total += 1
        self.class_num = len(self.data_pairs.keys())
        return
    
    def __len__(self):
        return self.total
        
    def __getitem__(self,idx):
        # idx is ignored: every call draws a fresh random triplet
        rds = np.random.randint(0,10000,size = 5)
        # anchor: random class, random image within that class
        rd_anchor_cls, rd_anchor_idx = rds[0], rds[1]
        rd_anchor_cls = rd_anchor_cls % self.class_num
        rd_anchor_idx = rd_anchor_idx % len(self.data_pairs[rd_anchor_cls])
        
        # positive: same class as the anchor, independently chosen image
        rd_pos_cls, rd_pos_idx = rd_anchor_cls, rds[2]
        rd_pos_idx = rd_pos_idx % len(self.data_pairs[rd_pos_cls])
        
        # negative: a different class; on a collision, step to the next class
        rd_neg_cls, rd_neg_idx = rds[3], rds[4]
        rd_neg_cls = rd_neg_cls % self.class_num
        if rd_neg_cls == rd_pos_cls:
            rd_neg_cls = (rd_neg_cls + 1)%self.class_num
        rd_neg_idx = rd_neg_idx % len(self.data_pairs[rd_neg_cls])
        
        # flag 1 loads each image as 3-channel color
        img_anchor = cv2.imread(self.data_pairs[rd_anchor_cls][rd_anchor_idx],1)
        img_pos = cv2.imread(self.data_pairs[rd_pos_cls][rd_pos_idx],1)
        img_neg = cv2.imread(self.data_pairs[rd_neg_cls][rd_neg_idx],1)
        
        img_anchor = cv2.resize(img_anchor, self.resize)
        img_pos = cv2.resize(img_pos, self.resize)
        img_neg = cv2.resize(img_neg, self.resize)
        

        # scale pixel values to [0, 1]
        img_anchor = np.float32(img_anchor)/255
        img_pos = np.float32(img_pos)/255
        img_neg = np.float32(img_neg)/255
        
        # HWC -> CHW, the layout the convolutional net expects
        img_anchor = np.transpose(img_anchor,(2,0,1))
        img_pos = np.transpose(img_pos,(2,0,1))
        img_neg = np.transpose(img_neg,(2,0,1))
        
        return (img_anchor, img_pos, img_neg)
           
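The sampling rule in `__getitem__` can be isolated from the image loading. Below is a minimal sketch over an in-memory `{label: [items]}` mapping; `sample_triplet` and the toy data are hypothetical names introduced here, and the sketch uses Python's `random` module rather than the modulo-on-random-integers approach of the class above.

```python
import random


def sample_triplet(items_by_class, rng=random):
    """Pick (anchor, positive, negative) from a {label: [items]} mapping.

    Needs at least two classes, otherwise no negative exists.
    """
    labels = sorted(items_by_class)
    num_classes = len(labels)

    anchor_cls = rng.randrange(num_classes)
    neg_cls = rng.randrange(num_classes)
    if neg_cls == anchor_cls:
        # Same fallback as the dataset class: step to the next class.
        neg_cls = (neg_cls + 1) % num_classes

    anchor = rng.choice(items_by_class[labels[anchor_cls]])
    positive = rng.choice(items_by_class[labels[anchor_cls]])
    negative = rng.choice(items_by_class[labels[neg_cls]])
    return anchor, positive, negative


# Toy data: item names start with their class label (illustrative only).
toy = {0: ["0_a", "0_b"], 1: ["1_a", "1_b"], 2: ["2_a"]}
anchor, positive, negative = sample_triplet(toy, random.Random(0))
print(anchor, positive, negative)
```

As in the dataset class, anchor and positive are drawn independently, so they can occasionally be the same image; such a triplet simply has zero anchor-positive distance.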
  2. Training code
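import time

import mxnet as mx
import numpy as np
from mxnet import nd

# ctx, alpha, feat_dim and show_feat are module-level names defined elsewhere in the script.
def train_net(net, train_iter, valid_iter, feat_iter,batch_size, trainer, num_epochs, lr_sch, save_prefix):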
    iter_num = 0
    for epoch in range(num_epochs):
        t0 = time.time()
        train_loss = []
        for batch in train_iter:
            iter_num += 1
            trainer.set_learning_rate(lr_sch(iter_num))
            anchor, pos, neg = batch
            #pdb.set_trace()
            X = nd.concat(anchor, pos, neg, dim=0) #combine three inputs along 0-dim to create one batch
            out = X.as_in_context(ctx)
            #print(out.shape)
            with mx.autograd.record(True):
                out = net(out)
                #out = out.as_in_context(mx.cpu())
                # split the outputs back into thirds; use the actual batch length so a short last batch still works
                n = anchor.shape[0]
                out_anchor = out[0:n]
                out_pos = out[n:n*2]
                out_neg = out[n*2 : n*3]
                loss_anchor_pos = (out_anchor - out_pos)**2
                loss_anchor_neg = (out_anchor - out_neg)**2
                #print(loss_anchor_pos.max())
                loss = loss_anchor_pos - loss_anchor_neg
                loss = nd.relu(loss.sum(axis=1) + alpha).mean()
            loss.backward()
            train_loss.append( loss.asscalar() )
            trainer.step(1)
           # print("\titer {} train loss {}".format(iter_num,np.asarray(train_loss).mean()))
            nd.waitall()
        if (epoch % 10) == 0 and feat_dim == 2:
            show_feat(epoch,net,feat_iter)
        print("epoch {} lr {:>.5} loss {:>.5} cost {:>.3}sec".format(epoch,trainer.learning_rate, \
                                                             np.asarray(train_loss).mean(),time.time() - t0))
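
`train_net` calls `lr_sch(iter_num)` on every iteration, so `lr_sch` only needs to be a callable that maps the iteration count to a learning rate. The post does not show the schedule that was actually used; the sketch below is one hypothetical possibility (a step decay), with `make_step_schedule` and all its parameter values chosen purely for illustration.

```python
def make_step_schedule(base_lr, decay, step_iters):
    """Build a callable: iteration number -> learning rate (step decay)."""
    def schedule(iter_num):
        # Multiply the rate by `decay` once every `step_iters` iterations.
        return base_lr * (decay ** (iter_num // step_iters))
    return schedule


# Hypothetical settings: start at 0.01 and halve every 1000 iterations.
lr_sch = make_step_schedule(base_lr=0.01, decay=0.5, step_iters=1000)

for it in (1, 999, 1000, 2500):
    print(it, lr_sch(it))
```

Any object with the same call signature can be passed as the `lr_sch` argument of `train_net`.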